Part 2 (Starting PyTorch)
This section walks you through creating a basic Python stage and initializing PyTorch within it.
First, let’s create the stage at
populate it with a skeleton that looks like:
from torchvision.models import detection
import numpy as np
import torch
import cv2

import pypipeline
import pyserialize
from ark_image_messages import *


class ImageDetectorStage(pypipeline.Stage):
    def __init__(self):
        pass

    def initialize(self, interface):
        # Initialize CUDA here, requiring the GPU.
        self.device = torch.device('cuda')

        # Next, we'll use a pre-trained model that comes with Pytorch.
        self.model = detection.retinanet_resnet50_fpn(
            pretrained=True, progress=True, pretrained_backbone=True)
        self.model.to(self.device)
        self.model.eval()

        # Add a publisher for detections (images, for debugging)
        self.image_publisher = interface.add_publisher(Image, "/detections/image")

        # Add a subscriber for images to run the detection on (using a well-known channel name).
        interface.add_subscriber(Image, "/realsense/color/compressed_image", self.handle_image)

    def shutdown(self):
        pass

    def handle_image(self, image):
        pass


# Register the stage with the pipeline
pypipeline.connect_and_execute_stage(ImageDetectorStage())
The comments in the code give more detail. At a high level, the stage creates a torch device on the GPU, downloads a pretrained model, and sets up one publisher and one subscriber. Right now, handle_image is empty, so the stage does nothing with the images it receives.
The last line is necessary to actually run your stage and communicate with the main Ark pipeline.
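To make the stage lifecycle easier to picture, here is a minimal pure-Python sketch of the wiring described above. It does not use Ark at all: FakeInterface, FakePublisher, EchoStage, run_stage, and the publisher's push method are hypothetical stand-ins invented for illustration, and the real pypipeline classes may behave differently. Only the add_publisher and add_subscriber call shapes and the two channel names are taken from the skeleton.

```python
from collections import defaultdict

# Hypothetical stand-ins: these mimic the calls used in the skeleton,
# but they are NOT the real pypipeline API.


class FakePublisher:
    """Records each message, then forwards it to any subscribers."""

    def __init__(self, interface, channel):
        self.interface = interface
        self.channel = channel

    def push(self, message):  # the method name 'push' is an assumption
        self.interface.published.append((self.channel, message))
        self.interface.deliver(self.channel, message)


class FakeInterface:
    """Keeps a callback list per channel, like a tiny message bus."""

    def __init__(self):
        self.subscribers = defaultdict(list)
        self.published = []

    def add_publisher(self, message_type, channel):
        return FakePublisher(self, channel)

    def add_subscriber(self, message_type, channel, callback):
        self.subscribers[channel].append(callback)

    def deliver(self, channel, message):
        # Invoke every callback registered on this channel.
        for callback in self.subscribers[channel]:
            callback(message)


class EchoStage:
    """Same shape as ImageDetectorStage, but republishes input unchanged."""

    def initialize(self, interface):
        self.image_publisher = interface.add_publisher(str, "/detections/image")
        interface.add_subscriber(
            str, "/realsense/color/compressed_image", self.handle_image)

    def shutdown(self):
        pass

    def handle_image(self, image):
        self.image_publisher.push(image)


def run_stage(stage, messages):
    """Drive the stage: initialize, feed it messages, then shut down."""
    interface = FakeInterface()
    stage.initialize(interface)
    for message in messages:
        interface.deliver("/realsense/color/compressed_image", message)
    stage.shutdown()
    return interface.published


published = run_stage(EchoStage(), ["frame-0", "frame-1"])
print(published)
```

In this sketch, run_stage plays the role that pypipeline.connect_and_execute_stage presumably plays in the real stage: it calls initialize once, routes each incoming message to the registered callback, and calls shutdown at the end. The strings stand in for Image messages.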
Before we can use this stage, we need to set up our pipeline, in Step 3.