This page documents a template for serving PyTorch models with TorchServe and communicating with them over gRPC. As far as I can tell, most of the content available online focuses on using HTTP to communicate with the served models. This page is therefore a good starting point for anyone who prefers gRPC for faster inference against served models.
This page does not explain the theory behind TorchServe or gRPC. Instead, it lays out the steps for serving a model and implementing a gRPC client.
project_root_folder
|__torchserve_grpc
    |__model_store
    |   |__digit-model.mar
    |__model_weights
    |   |__digitcnn_state_dict.pth
    |__images
    |   |__test.png
    |__handler.py
    |__model.py
    |__inference.proto
    |__management.proto
    |__inference_pb2.py
    |__inference_pb2_grpc.py
    |__management_pb2.py
    |__management_pb2_grpc.py
    |__grpc_client.py
import torch, torchvision
class DigitCNN(torch.nn.Module):
    """MobileNetV3-Small backbone with a custom 10-way classification head.

    The wrapped torchvision network is stored on ``self.model``, so every
    entry of this module's state dict is prefixed with ``model.``.
    """

    def __init__(self):
        super().__init__()
        self.model = self._model_prep()

    def _model_prep(self):
        """Build the network: pretrained backbone plus a fresh classifier.

        Returns:
            The torchvision MobileNetV3-Small module whose ``classifier`` has
            been replaced by a 576 -> 1024 -> 10 head ending in a softmax, and
            whose first 11 ``features`` sub-modules are frozen.
        """
        # NOTE(review): ``pretrained=True`` makes torchvision fetch the
        # pretrained weights when they are not cached locally, and newer
        # torchvision releases prefer the ``weights=`` argument -- confirm
        # against the torchvision version this project pins.
        model = torchvision.models.mobilenet_v3_small(pretrained=True, progress=True)

        classification_layer = torch.nn.Sequential(
            torch.nn.Linear(576, out_features=1024),
            torch.nn.Hardswish(),
            torch.nn.Dropout(p=0.5),
            torch.nn.Linear(1024, 10),
            # Normalise across the class dimension explicitly. Leaving ``dim``
            # unset relies on PyTorch's deprecated implicit-dimension choice
            # and emits a warning on every forward pass; for the 2-D
            # (batch, classes) input this layer receives, the implicit choice
            # is also dim=1, so the outputs are unchanged.
            torch.nn.Softmax(dim=1),
        )
        model.classifier = classification_layer

        # Freeze the first 11 feature sub-modules so that only the remaining
        # feature layers and the new classifier receive gradient updates.
        for param in model.features[:11].parameters():
            param.requires_grad = False
        return model

    def forward(self, x):
        """Run ``x`` through the wrapped network and return its softmax output."""
        return self.model(x)
from ts.torch_handler.vision_handler import VisionHandler
import torch
from PIL import Image
from torchvision import transforms
import logging
import io
import base64
class CustomHandler(VisionHandler):
    """TorchServe vision handler that turns request payloads into a batch
    tensor and model outputs into a list of integer predictions."""

    def __init__(self):
        super().__init__()
        # Every decoded image is resized to 96x96 and converted to a tensor.
        pipeline_steps = [transforms.Resize((96, 96)), transforms.ToTensor()]
        self.image_processing = transforms.Compose(pipeline_steps)

    def preprocess(self, data):
        """Convert a list of request rows into one stacked tensor.

        Each row's payload is read from its ``"data"`` key, falling back to
        ``"body"``. A ``str`` payload is base64-decoded first; raw bytes are
        opened as an RGB image and passed through ``self.image_processing``;
        anything else is wrapped in a ``torch.FloatTensor`` as-is.

        Returns:
            The per-row tensors stacked together and moved to ``self.device``
            (presumably set by the base handler during initialisation --
            verify against the TorchServe version in use).
        """
        batch = []
        for request in data:
            # Compat layer: the envelope normally hands back the data directly,
            # but older TorchServe versions had no envelope.
            image = request.get("data") or request.get("body")

            # A string payload is treated as base64-encoded image bytes.
            if isinstance(image, str):
                image = base64.b64decode(image)

            if not isinstance(image, (bytearray, bytes)):
                # Not raw bytes: the payload is handed to FloatTensor directly.
                image = torch.FloatTensor(image)
            else:
                image = Image.open(io.BytesIO(image)).convert("RGB")
                image = self.image_processing(image)
                logging.info(f"image shape after preprocess: {image.shape}")

            batch.append(image)

        stacked = torch.stack(batch)
        return stacked.to(self.device)

    def postprocess(self, data):
        """Map model outputs to a list with one integer prediction per row.

        NOTE(review): the ``+ 1`` shifts the arg-max index to a 1-based
        value -- confirm this matches the label scheme used in training.
        """
        logging.info("Inside Post Process")
        logging.info(f"Outputs: {data}")
        winning_index = torch.argmax(data, axis=1)
        labels = (winning_index + 1).tolist()
        logging.info(labels)
        return labels
# Package the model definition, trained weights and custom handler into a
# .mar archive inside the model store (--force overwrites an existing archive).
torch-model-archiver \
    --model-name digit-model \
    --export-path torchserve_grpc/model_store \
    --version 1.0 \
    --model-file torchserve_grpc/model.py \
    --serialized-file torchserve_grpc/model_weights/digitcnn_state_dict.pth \
    --handler torchserve_grpc/handler.py \
    --force

# Start TorchServe and register the archive under the model name "digitmodel".
torchserve \
    --start \
    --model-store torchserve_grpc/model_store \
    --models digitmodel=digit-model.mar \
    --no-config-snapshot