This page documents a template for serving PyTorch models with TorchServe and communicating with them over gRPC. In my experience, most of the content available online focuses on using HTTP to communicate with served models. This page is therefore aimed at readers who prefer gRPC for faster inference with their served models.

This page does not explain the theory behind TorchServe or gRPC. Instead, it lays out the steps for serving the models and implementing a gRPC client.

Directory structure

project_root_folder
|__torchserve_grpc
	|__model_store
		|__digit-model.mar
	|__model_weights
		|__digitcnn_state_dict.pth
	|__images
		|__test.png
	|__handler.py
	|__model.py
	|__inference.proto
	|__management.proto
	|__inference_pb2.py
	|__inference_pb2_grpc.py
	|__management_pb2.py
	|__management_pb2_grpc.py
	|__grpc_client.py

Model Definition file (model.py)

import torch, torchvision

class DigitCNN(torch.nn.Module):
    """Digit classifier: a MobileNetV3-Small backbone with a new 10-way head.

    The stock ImageNet classifier of ``torchvision.models.mobilenet_v3_small``
    is replaced by a small MLP that ends in a softmax, so ``forward`` returns
    per-class probabilities rather than logits. The first 11 modules of the
    feature extractor are frozen (``requires_grad = False``).
    """

    def __init__(self, pretrained=True):
        """Build the network.

        Args:
            pretrained: Forwarded to ``mobilenet_v3_small``. When True (the
                default, matching the original behaviour), the backbone starts
                from torchvision's ImageNet weights, which may be downloaded.
                If a full state dict for this architecture is loaded afterwards,
                those initial weights are overwritten anyway, so passing
                ``pretrained=False`` avoids an unnecessary download.
        """
        super().__init__()
        self.model = self._model_prep(pretrained)

    def _model_prep(self, pretrained=True):
        """Return a MobileNetV3-Small with a 10-class head and a partly frozen backbone."""
        model = torchvision.models.mobilenet_v3_small(pretrained=pretrained, progress=True)

        # Take the head's input width from the stock classifier (576 for
        # MobileNetV3-Small) instead of hard-coding it, so the new head always
        # matches the backbone's output width.
        in_features = model.classifier[0].in_features
        model.classifier = torch.nn.Sequential(
            torch.nn.Linear(in_features, out_features=1024),
            torch.nn.Hardswish(),
            torch.nn.Dropout(p=0.5),
            torch.nn.Linear(1024, 10),
            # Normalise over the class axis of the (batch, 10) output. An
            # implicit dim is deprecated and warns; for 2-D input it already
            # resolved to dim=1, so results are unchanged.
            # NOTE(review): if training used CrossEntropyLoss, which applies
            # log-softmax itself, this extra softmax is redundant. Confirm it
            # matches how the weights were trained.
            torch.nn.Softmax(dim=1),
        )

        # Freeze the early feature blocks; only the later blocks and the new
        # head keep requires_grad=True.
        for param in model.features[:11].parameters():
            param.requires_grad = False
        return model

    def forward(self, x):
        """Return class probabilities for an image batch ``x``.

        ``x`` is presumably (batch, 3, H, W); the handler in this project
        feeds 96x96 RGB tensors. TODO: confirm for other callers.
        """
        return self.model(x)

Custom Handler (handler.py)

from ts.torch_handler.vision_handler import VisionHandler
import torch
from PIL import Image
from torchvision import transforms
import logging
import io
import base64

class CustomHandler(VisionHandler):
    """TorchServe handler for the digit classifier.

    ``preprocess`` turns each request payload (raw image bytes, base64 text or
    a nested list of numbers) into a tensor and stacks them into one batch.
    ``postprocess`` turns the model output into one integer label per request.
    """

    def __init__(self):
        super().__init__()
        # Resize every decoded image to 96x96 and convert it to a (C, H, W)
        # float tensor. No normalisation is applied here.
        self.image_processing = transforms.Compose([
            transforms.Resize((96, 96)),
            transforms.ToTensor()])

    def preprocess(self, data):
        """Convert a batch of requests into a single input tensor.

        Args:
            data: Iterable of request rows, one per batched request. Each row
                is a mapping whose payload is read from ``"data"``, falling
                back to ``"body"``. A payload may be raw image bytes, a
                base64-encoded string of those bytes, or a nested list of
                numbers.

        Returns:
            The per-request tensors stacked along a new first dimension and
            moved to ``self.device``. Image payloads become (3, 96, 96)
            tensors, so a batch of images has shape (batch, 3, 96, 96).

        Raises:
            ValueError: if a row has neither a usable ``"data"`` nor a
                ``"body"`` payload.
        """
        images = []
        for row in data:
            # Compat layer: normally the envelope should just return the data
            # directly, but older versions of Torchserve didn't have envelope.
            image = row.get("data") or row.get("body")
            if image is None:
                # Previously this fell through to torch.FloatTensor(None) and
                # failed with an obscure TypeError.
                raise ValueError("Request row has no 'data' or 'body' payload")

            if isinstance(image, str):
                # String payloads are treated as base64-encoded image bytes.
                image = base64.b64decode(image)

            if isinstance(image, (bytearray, bytes)):
                # Encoded image bytes: decode, force 3-channel RGB, then resize
                # and convert to a tensor.
                image = Image.open(io.BytesIO(image)).convert("RGB")
                image = self.image_processing(image)
                logging.info("image shape after preprocess: %s", image.shape)
            else:
                # Anything else is assumed to be a nested list of numbers.
                image = torch.FloatTensor(image)

            images.append(image)

        return torch.stack(images).to(self.device)

    def postprocess(self, data):
        """Return one integer label per request from the model output.

        Args:
            data: Model output, presumably (batch, num_classes) class scores.

        Returns:
            list[int]: for each row, the index of the highest score along
            dim 1, plus one.
        """
        logging.info("Inside Post Process")
        # Lazy %-formatting: the (possibly large) tensor is only rendered if
        # INFO logging is actually enabled.
        logging.info("Outputs: %s", data)
        # NOTE(review): "+ 1" reports class index i as label i + 1 (class 0 ->
        # 1, class 9 -> 10). This is only correct if the training labels were
        # 1-based; confirm against the training data.
        labels = (torch.argmax(data, dim=1) + 1).tolist()
        logging.info("%s", labels)
        return labels

Archiving the model

# Package model.py, the trained state dict and handler.py into
# torchserve_grpc/model_store/digit-model.mar (overwriting any existing archive).
torch-model-archiver \
    --model-name digit-model \
    --version 1.0 \
    --model-file torchserve_grpc/model.py \
    --serialized-file torchserve_grpc/model_weights/digitcnn_state_dict.pth \
    --handler torchserve_grpc/handler.py \
    --export-path torchserve_grpc/model_store \
    --force

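Serving the model

The `--models digitmodel=digit-model.mar` argument below registers the archive under the model name `digitmodel`; gRPC requests must use this name. By default, TorchServe's gRPC inference API listens on port 7070 and its gRPC management API on port 7071.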

# Start TorchServe, loading digit-model.mar from the model store under the
# model name "digitmodel", without writing config snapshots.
torchserve --start \
    --model-store torchserve_grpc/model_store \
    --models digitmodel=digit-model.mar \
    --no-config-snapshot