Create a Model Bundle

A ModelBundle consists of a trained model as well as the surrounding preprocessing and postprocessing code.

Specifically, a model bundle consists of two Python objects:
either a model or a load_model_fn (a function that takes no arguments and returns the model); and
a load_predict_fn.
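The "either a model or load_model_fn" choice can be illustrated with a small helper. resolve_model is hypothetical and not part of any library. It only mimics the described either/or semantics, assuming exactly one of the two is supplied. One plausible reason to pass load_model_fn instead of a model is to defer loading (for example, from disk) until the function is called; that motivation is an assumption here.

```python
from typing import Any, Callable, Optional


def resolve_model(
    model: Optional[Any] = None,
    load_model_fn: Optional[Callable[[], Any]] = None,
) -> Any:
    """Hypothetical helper: return the model from exactly one of the two sources."""
    # Assumption: supplying both or neither is an error
    if (model is None) == (load_model_fn is None):
        raise ValueError("Provide exactly one of model or load_model_fn")
    # load_model_fn takes no arguments and is only called here, when needed
    return model if model is not None else load_model_fn()


def load_model_fn():
    # Stand-in for loading a trained model, e.g. from a file
    return lambda x: x * 2


model = resolve_model(load_model_fn=load_model_fn)
print(model(21))  # 42
```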

load_predict_fn takes the model as its argument and returns a function, predict_fn, that takes one argument representing the model input and returns one value representing the model output.

Typically, the model is a PyTorch nn.Module or a TensorFlow model, but it can also be arbitrary Python code.
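The load_predict_fn / predict_fn contract can be sketched with a plain Python callable standing in for the model. This is a minimal illustration, not Launch's implementation: toy_model (here just the built-in sum) is a hypothetical stand-in, and predict_fn simply forwards its single input to the model.

```python
from typing import Any, Callable


def load_predict_fn(model: Callable[[Any], Any]) -> Callable[[Any], Any]:
    # The returned predict_fn takes one argument (the model input)
    # and returns one value (the model output).
    def predict_fn(model_input: Any) -> Any:
        return model(model_input)

    return predict_fn


# Any Python callable can act as the "model"; here, one that sums a list
toy_model = sum
predict_fn = load_predict_fn(toy_model)
print(predict_fn([1, 2, 3]))  # 6
```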

load_predict_fn is therefore a higher-order function: the predict_fn it returns defines how the model object is used, including how the model is called and any preprocessing and postprocessing needed.
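The next sketch adds preprocessing and postprocessing around the model call, mirroring the structure of the object detection example below. The input format (a comma-separated string of numbers), the stand-in model, and the output dictionary are all hypothetical; only the shape of the pattern, a load_predict_fn that closes over the model and returns predict_fn, comes from this page.

```python
from typing import Callable, Dict, List


def load_predict_fn(
    model: Callable[[List[float]], float],
) -> Callable[[str], Dict[str, object]]:
    def preprocess(raw: str) -> List[float]:
        # Hypothetical input format: comma-separated numbers, e.g. "1.5,2,3"
        return [float(tok) for tok in raw.split(",")]

    def postprocess(score: float) -> Dict[str, object]:
        # Hypothetical output format: a label plus the raw score
        return {"label": "positive" if score > 0 else "negative", "score": score}

    def predict_fn(raw: str) -> Dict[str, object]:
        # predict_fn closes over model, preprocess, and postprocess
        return postprocess(model(preprocess(raw)))

    return predict_fn


# Stand-in model: sum of the inputs minus a fixed threshold
predict_fn = load_predict_fn(lambda xs: sum(xs) - 5.0)
print(predict_fn("1.5,2,3"))  # {'label': 'positive', 'score': 1.5}
```

Keeping preprocess and postprocess as inner functions means predict_fn is the only thing the caller needs, while the model stays captured in the closure.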

Here is an example of a model bundle for an object detection model.
In load_predict_fn, we split the logic into preprocessing and postprocessing steps, call those functions inside predict, and return predict. Note that the output of predict must conform to the Nucleus annotation format.

import io

import torch
import torchvision.transforms as T
from PIL import Image


def load_model_fn():
    # Load the trained model; the path to the saved model goes here
    model = torch.load('')
    # Switch to inference mode (torchvision detection models expect eval() when predicting)
    model.eval()
    return model


# load_predict_fn receives the model and returns predict at the end
def load_predict_fn(model):
    # Use the GPU if one is available, otherwise fall back to the CPU
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    model.to(device)

    def preprocess(img_bytes):
        # Transform the image bytes -> PIL image -> PyTorch tensor
        img = Image.open(io.BytesIO(img_bytes)).convert('RGB')
        return T.ToTensor()(img)

    def predict(img_bytes):
        # Run the model input through the PyTorch model; this is the function that is returned
        tensor = preprocess(img_bytes).to(device)
        with torch.no_grad():
            # Torchvision-style detection models take a list of image tensors
            model_output = model([tensor])
        return postprocess(model_output)

    def postprocess(model_output):
        # Transform the model's output (PyTorch tensors) into JSON that gets
        # ingested by Nucleus; this needs to conform to the Nucleus annotation format.
        raw_boxes = model_output[0]['boxes'].cpu().numpy()  # in x1, y1, x2, y2 format
        raw_labels = model_output[0]['labels'].cpu().numpy()
        raw_scores = model_output[0]['scores'].cpu().numpy()
        box_objects = []
        for box, label, score in zip(raw_boxes, raw_labels, raw_scores):
            box_objects.append({
                'geometry': {
                    'x': float(box[0]),
                    'y': float(box[1]),
                    'width': float(box[2] - box[0]),
                    'height': float(box[3] - box[1]),
                },
                'type': 'box',
                'label': str(label),
                'confidence': float(score),
            })
        return box_objects

    return predict
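Because the example above needs PyTorch and a trained detector to run, the conversion done in postprocess can be checked on its own. The sketch below (boxes_to_annotations is a hypothetical name) applies the same corner-to-width/height arithmetic to plain Python lists, with the dictionary keys copied from the example. It does not validate against the full Nucleus annotation schema.

```python
from typing import Dict, List, Sequence


def boxes_to_annotations(
    boxes: Sequence[Sequence[float]],
    labels: Sequence[object],
    scores: Sequence[float],
) -> List[Dict[str, object]]:
    annotations = []
    # Each box is (x1, y1, x2, y2): top-left and bottom-right corners
    for (x1, y1, x2, y2), label, score in zip(boxes, labels, scores):
        annotations.append({
            "geometry": {
                "x": float(x1),
                "y": float(y1),
                "width": float(x2 - x1),   # right edge minus left edge
                "height": float(y2 - y1),  # bottom edge minus top edge
            },
            "type": "box",
            "label": str(label),
            "confidence": float(score),
        })
    return annotations


print(boxes_to_annotations([[10, 20, 50, 80]], [3], [0.9]))
# [{'geometry': {'x': 10.0, 'y': 20.0, 'width': 40.0, 'height': 60.0}, 'type': 'box', 'label': '3', 'confidence': 0.9}]
```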