Part 2 (Starting PyTorch)
This section walks you through creating a basic Python stage and initializing PyTorch within it.
First, let’s create the stage at ${APP_ROOT}/tutorial/detection_stage.py. We’ll populate it with a skeleton that looks like this:
```python
from torchvision.models import detection
import numpy as np
import torch
import cv2
import pypipeline
import pyserialize
from ark_image_messages import *


class ImageDetectorStage(pypipeline.Stage):
    def __init__(self):
        pass

    def initialize(self, interface):
        # Initialize CUDA here; this requires a CUDA-capable GPU.
        self.device = torch.device('cuda')

        # Next, load a RetinaNet (ResNet-50 FPN) detector pre-trained on COCO,
        # which ships with torchvision. The weights are downloaded on first use
        # and cached locally. (torchvision 0.13+ deprecates `pretrained=` and
        # `pretrained_backbone=` in favor of `weights=` and `weights_backbone=`.)
        self.model = detection.retinanet_resnet50_fpn(
            pretrained=True,
            progress=True,
            pretrained_backbone=True)
        self.model.to(self.device)
        # Inference mode: the model returns detections instead of training losses.
        self.model.eval()

        # Add a publisher for detections (published as images, for debugging).
        self.image_publisher = interface.add_publisher(Image, "/detections/image")

        # Add a subscriber for the images to run detection on (using a well-known channel name).
        interface.add_subscriber(Image, "/realsense/color/compressed_image", self.handle_image)

    def shutdown(self):
        pass

    def handle_image(self, image):
        pass


# Register the stage with the pipeline and run it.
pypipeline.connect_and_execute_stage(ImageDetectorStage())
```
The comments explain each step. At a high level, the stage creates a torch device, downloads a pre-trained detection model, and sets up a publisher and a subscriber. Right now, it doesn’t do anything with the incoming images: handle_image is still a stub.
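Once handle_image is filled in, each frame has to be converted into the form torchvision's detection models accept: a list of float tensors shaped [C, H, W] with RGB values in [0, 1]. The sketch below shows only the NumPy side of that conversion and assumes the frame has already been decoded by OpenCV into an H×W×3 uint8 BGR array; how the Ark Image message exposes its pixel data is not covered here, and to_model_input is a hypothetical helper name, not part of Ark or torchvision.

```python
import numpy as np


def to_model_input(frame_bgr: np.ndarray) -> np.ndarray:
    """Hypothetical helper: OpenCV-style HxWx3 uint8 BGR -> 3xHxW float32 RGB in [0, 1]."""
    if frame_bgr.ndim != 3 or frame_bgr.shape[2] != 3:
        raise ValueError(f"expected an HxWx3 image, got shape {frame_bgr.shape}")
    rgb = frame_bgr[:, :, ::-1]          # OpenCV decodes color images as BGR; flip to RGB
    chw = np.transpose(rgb, (2, 0, 1))   # HWC -> CHW, the layout torchvision expects
    # Make a contiguous float32 copy (no negative strides), then scale 0-255 to 0-1.
    return np.ascontiguousarray(chw, dtype=np.float32) / 255.0
```

Inside the stage, torch.from_numpy(to_model_input(frame)).to(self.device) would then yield a tensor to pass to self.model inside a list; in eval mode the model returns one dict per image with boxes, labels, and scores.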
The last line of the skeleton, pypipeline.connect_and_execute_stage(ImageDetectorStage()), is what actually runs your stage and lets it communicate with the main Ark pipeline.
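Outside a running pipeline, nothing calls initialize() or delivers messages to handle_image, so the stage's wiring can't be exercised on its own. For quick offline checks, one option is a hand-written stand-in for the interface object. Everything in the sketch below is hypothetical: FakeInterface, FakePublisher, deliver, FakeImage, and CountingStage are illustrative names, not part of pypipeline or ark_image_messages. They mimic only the two calls the skeleton makes (add_publisher and add_subscriber, with the same argument order) and leave out PyTorch so the sketch runs anywhere.

```python
from typing import Any, Callable, Dict, List, Tuple


class FakePublisher:
    """Hypothetical stand-in that only records which type/channel it was created for."""
    def __init__(self, msg_type: type, channel: str):
        self.msg_type = msg_type
        self.channel = channel


class FakeInterface:
    """Hypothetical stand-in for the interface object passed to initialize()."""
    def __init__(self):
        self.publishers: Dict[str, FakePublisher] = {}
        self.subscribers: Dict[str, List[Tuple[type, Callable[[Any], None]]]] = {}

    def add_publisher(self, msg_type: type, channel: str) -> FakePublisher:
        pub = FakePublisher(msg_type, channel)
        self.publishers[channel] = pub
        return pub

    def add_subscriber(self, msg_type: type, channel: str,
                       callback: Callable[[Any], None]) -> None:
        self.subscribers.setdefault(channel, []).append((msg_type, callback))

    def deliver(self, channel: str, message: Any) -> int:
        """Invoke every callback subscribed to `channel`; return how many ran."""
        handlers = self.subscribers.get(channel, [])
        for _msg_type, callback in handlers:
            callback(message)
        return len(handlers)


class FakeImage:
    """Placeholder for the real Image message type."""


class CountingStage:
    """Same initialize/handle_image shape as ImageDetectorStage, minus PyTorch."""
    def initialize(self, interface):
        self.received = 0
        self.image_publisher = interface.add_publisher(FakeImage, "/detections/image")
        interface.add_subscriber(FakeImage, "/realsense/color/compressed_image",
                                 self.handle_image)

    def handle_image(self, image):
        self.received += 1
```

With this in place, calling stage.initialize(iface) and then iface.deliver("/realsense/color/compressed_image", FakeImage()) invokes handle_image once, which is enough to check that a stage subscribes to the channel you expect.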
Before we can use this stage, we need to set up our pipeline, which we’ll do in Step 3.