In this tutorial, we will show how to get Meta’s ‘Segment Anything’ model up and running to recognize items a person is wearing. Think of the shopping feature on Instagram: it identifies clothes and associates a link with each item.

To see the final implementation, you can view it here

Basic Setup

Developing models with Cerebrium should feel identical to developing on a virtual machine or Google Colab, so converting existing code should be very easy! Please make sure you have the Cerebrium package installed and that you are logged in. If not, please take a look at our docs here

First we create our project:

cerebrium init segment-anything

To start, we need to create a file which will contain our main Python code. This is a relatively simple implementation, so we can do everything in one file.

I then go straight to the Meta Repo to see how I can run this code.

We need certain Python packages to implement this project. Let’s add those to the [cerebrium.dependencies.pip] section of our cerebrium.toml file:

segment-anything = "git+|
opencv-python = "latest"
pycocotools = "latest"

I then define the following in my file

import os
import cv2
import requests
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry

checkpoint_path = "/persistent-storage/segment-anything/sam_vit_h_4b8939.pth"

print("Downloading file...")
if not os.path.exists(checkpoint_path):
    response = requests.get("")  # URL of the ViT-H checkpoint
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
    with open(checkpoint_path, "wb") as f:
        f.write(response.content)  # Save the checkpoint where it is loaded from below
    print("Download complete")
else:
    print("File already exists")

sam = sam_model_registry["default"](checkpoint=checkpoint_path)
sam.to("cuda")  # Move the model to the GPU
mask_generator = SamAutomaticMaskGenerator(
    sam,
    min_mask_region_area=100,  # Requires open-cv to run post-processing
)

We define this at the top of the file for two reasons:

  1. Code at the top of the file runs automatically on deployment, so the model is downloaded and loaded into memory during the first deployment. The first deployment is the longest, but subsequent ones should be much quicker.
  2. If this code were inside a function, it would run on the first request rather than on deploy, and the model would be reloaded from cache on every request. We want neither.
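
The difference between the two loading strategies can be illustrated with a small, self-contained sketch. Note that `load_model` is a hypothetical stand-in for the SAM loading code, not a Cerebrium or Meta API; it only counts how often it runs.

```python
load_count = 0

def load_model():
    """Hypothetical stand-in for an expensive model load (e.g. reading a checkpoint)."""
    global load_count
    load_count += 1
    return {"name": "stand-in model"}

# Module level: runs once, when the file is first executed.
MODEL = load_model()

def predict_with_global(x):
    # Reuses the model that is already in memory.
    return (MODEL["name"], x)

def predict_with_local_load(x):
    # Anti-pattern: the load cost is paid again on every request.
    model = load_model()
    return (model["name"], x)

for i in range(3):
    predict_with_global(i)
loads_after_global_calls = load_count  # only the module-level load has happened

for i in range(3):
    predict_with_local_load(i)
loads_after_local_calls = load_count   # one extra load per request
```

With three requests each, the module-level version loads the model once in total, while the in-function version adds one load per request.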

With the SAM model, you can set parameters that control how precise and how detailed the segmentation is. Since our use case is detecting clothes in an image, we don’t want the segmentation to be that granular. See Meta’s documentation for the different parameters you can set.
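
As a rough sketch of what a coarser configuration might look like, the snippet below collects a few `SamAutomaticMaskGenerator` keyword arguments in a dictionary. The parameter names come from Meta's `segment_anything` package; the specific values are illustrative assumptions rather than tuned settings from this tutorial, and `build_mask_generator` is a hypothetical helper.

```python
# Illustrative values only; tune them against your own images.
COARSE_MASK_KWARGS = {
    "points_per_side": 16,           # sparser prompt grid than the library default of 32
    "pred_iou_thresh": 0.9,          # keep only masks the model rates as high quality
    "stability_score_thresh": 0.95,  # drop masks that are unstable under threshold changes
    "min_mask_region_area": 100,     # remove tiny regions and holes (needs opencv)
}

def build_mask_generator(sam_model, **overrides):
    """Hypothetical helper: build a generator from the settings above plus overrides."""
    # Imported lazily so the settings can be inspected without the package installed.
    from segment_anything import SamAutomaticMaskGenerator

    kwargs = {**COARSE_MASK_KWARGS, **overrides}
    return SamAutomaticMaskGenerator(sam_model, **kwargs)
```

Keeping the settings in one dictionary makes it easy to compare a coarse and a fine configuration on the same image before deploying.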

In the documentation, we need to select a model type and a checkpoint. Meta offers three model types; we will just use the default. I then download the ViT-H SAM model checkpoint. You can add this checkpoint file to the same directory as your main Python file, and when you deploy your model, Cerebrium will automatically upload it.

I would then like the user to send a few parameters to the model via the API, so I need to define what the request object looks like:

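from pydantic import BaseModel
from typing import Optional

class Item(BaseModel):
    image: Optional[str] = None     # base64-encoded image
    file_url: Optional[str] = None  # publicly accessible image URL
    coordinates: list               # [x, y] point on the item of interest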

Pydantic is a data validation library, and BaseModel is where Cerebrium keeps some default parameters like “webhook_url”, which lets a user send in a webhook URL that we call when the job has finished processing. This is useful for long-running tasks, but do not worry about that functionality for this tutorial. The user can send either an image or a file_url: the former is a base64-encoded image, the latter a publicly accessible URL we can download.
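
For the base64 option, the client has to encode the raw image bytes before putting them in the `image` field. A minimal sketch using only the standard library follows. `build_payload` is a hypothetical helper, the `coordinates` key mirrors the `Item` model above, and the byte string is a stand-in rather than a real picture.

```python
import base64
import json

def build_payload(image_bytes=None, file_url=None, coordinates=(0, 0)):
    """Hypothetical helper: build the JSON body for either input style."""
    if image_bytes is None and file_url is None:
        raise ValueError("Provide image_bytes or file_url")
    payload = {"coordinates": list(coordinates)}
    if image_bytes is not None:
        # base64 turns arbitrary bytes into ASCII text that is safe inside JSON.
        payload["image"] = base64.b64encode(image_bytes).decode("ascii")
    else:
        payload["file_url"] = file_url
    return json.dumps(payload)

fake_image = b"\x89PNG\r\n\x1a\n" + bytes(range(16))  # stand-in bytes, not a real image
body = build_payload(image_bytes=fake_image, coordinates=(400, 300))

# The server side reverses the encoding to recover the original bytes.
decoded = base64.b64decode(json.loads(body)["image"])
round_trip_ok = decoded == fake_image
```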

Identifying and classifying objects

Next, we would like to create three functions with the following functionality:

  1. Based on the image coordinates sent in, we would like to find the annotation (segment) the user is most likely referring to.
  2. Take the annotation we have focused on from Meta’s model, create a new image, and crop it so that it shows only the item we are interested in. We crop the image so that the classification model performs more accurately.
  3. Pass this new image into an object classification model, which tells us what the item is.

def find_annotation_by_coordinates(annotations, x, y):
    for ann in annotations:
        bbox_x, bbox_y, bbox_w, bbox_h = ann['bbox']
        if bbox_x <= x <= bbox_x + bbox_w and bbox_y <= y <= bbox_y + bbox_h:
            return ann
    return None

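import numpy as np

def create_image(ann):
  # Reads the request image from the module-level `image` variable
  m = ann['segmentation']  # boolean mask with shape (height, width)
  resized_original_image = cv2.resize(image, (m.shape[1], m.shape[0]))
  mask = np.ones((m.shape[0], m.shape[1], 3), dtype=np.uint8) * 255  # white canvas
  mask[m] = resized_original_image[m]  # Copy the segmented pixels; everything else stays white
  x, y, w, h = (int(v) for v in ann['bbox'])  # slice indices must be integers
  cropped_image = mask[y:y+h, x:x+w]

  return cv2.cvtColor(cropped_image, cv2.COLOR_RGB2BGR)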

from transformers import ViTImageProcessor, ViTForImageClassification
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')

def classify(image):
    inputs = processor(images=image, return_tensors="pt")
    outputs = model(**inputs)
    logits = outputs.logits
    # model predicts one of the 1000 ImageNet classes
    predicted_class_idx = logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]

In the last function, we use a Vision Transformer model from Google that identifies objects and returns the most likely class. You will notice we don’t define the processor and model inside the function: we only want these models to load once, so define them near the top of your file, under your SAM code.
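
One thing to keep in mind about the coordinate lookup above: it returns the first annotation whose bounding box contains the point, which may be a large background segment rather than the garment. A possible refinement, sketched here as an assumption and not part of the original implementation, is to prefer the smallest matching box. `find_smallest_annotation` is a hypothetical name, and the `label` key in the toy data exists only to make the demo readable; SAM annotations do not carry labels.

```python
def find_smallest_annotation(annotations, x, y):
    """Return the containing annotation with the smallest bounding-box area, or None."""
    matches = []
    for ann in annotations:
        bx, by, bw, bh = ann["bbox"]  # [x, y, width, height]
        if bx <= x <= bx + bw and by <= y <= by + bh:
            matches.append(ann)
    if not matches:
        return None
    return min(matches, key=lambda ann: ann["bbox"][2] * ann["bbox"][3])

# Toy annotations: a full-frame box and a smaller box inside it.
toy_annotations = [
    {"label": "background", "bbox": [0, 0, 800, 600]},
    {"label": "shirt", "bbox": [350, 250, 120, 160]},
]
picked = find_smallest_annotation(toy_annotations, 400, 300)
missed = find_smallest_annotation(toy_annotations, 900, 900)
```

Here the point (400, 300) lies inside both boxes, and the smaller one is selected; a point outside every box yields `None`, just like the original function.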

Putting it all together

Next, we create the function that is called every time a user submits a request. It downloads or decodes the image the user submitted, feeds it into the model, and returns the item detected in the image. Our code looks as follows:

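import requests
import numpy as np
from PIL import Image
from io import BytesIO
import base64

def download_image(url):
    if not url:
        raise ValueError("No valid url passed")
    r = requests.get(url)
    r.raise_for_status()  # Fail early on a bad HTTP status
    return np.array(Image.open(BytesIO(r.content)).convert("RGB"))


def predict(item, run_id, logger):
    global image  # create_image() reads the current image from module scope

    if not item.image and not item.file_url:
        return {"message": "No image or file_url provided."}

    if item.image:
        # Decode the base64 string into an image array
        decoded = Image.open(BytesIO(base64.b64decode(item.image))).convert("RGB")
        image = cv2.cvtColor(np.array(decoded), cv2.COLOR_BGR2RGB)
    elif item.file_url:
        image = cv2.cvtColor(download_image(item.file_url), cv2.COLOR_BGR2RGB)

    masks = mask_generator.generate(image)
    cursor_x, cursor_y = item.coordinates
    selected_annotation = find_annotation_by_coordinates(masks, cursor_x, cursor_y)

    if not selected_annotation:
        return {"message": "No annotation found at the given coordinates."}
    segmented_image = create_image(selected_annotation)
    result = classify(segmented_image)

    return {"result": result}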

In the above code we do a few things:

  1. We import the various packages we need and add the necessary ones to the dependencies in our cerebrium.toml.
  2. We create a function that downloads the image from a URL if the user supplied one.
  3. We run our “Segment Anything” model on the image the user sent through the request, then find, crop, and classify the segment at the given coordinates.
  4. Lastly, we respond with the predicted label. All returns from your endpoint function must be of type list or JSON.

We can then deploy our model to an AMPERE_A5000 instance with the following command:

cerebrium deploy --name segment-anything --gpu AMPERE_A5000

After a few minutes, your model should be deployed and an endpoint should be returned. Let us send a curl request to see the response:

curl --location --request POST '' \
--header 'Authorization: public-XXXXXXXXXXXX' \
--header 'Content-Type: application/json' \
--data-raw '{
    "file_url": "",
    "coordinates": [400,300]
}'

And here is the output of our model

{
  "run_id": "526ca24d-c3ef-41ce-a29c-6c03e7e9585c",
  "message": "success",
  "result": {
    "result": "academic gown, academic robe, judge's robe"
  },
  "run_time_ms": 13880.023717880249
}
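
If you would rather call the endpoint from Python, the request and response handling could look roughly like this. It is a sketch using only the standard library: `classify_remote` and `extract_label` are hypothetical helpers, the endpoint URL and API key are values you must supply yourself, the headers copy the curl command above, and the `coordinates` key follows the `Item` model defined earlier. Nothing is sent until you call `classify_remote`.

```python
import json
import urllib.request

def extract_label(response_text):
    """Pull the predicted label out of a response shaped like the sample output."""
    data = json.loads(response_text)
    return data["result"]["result"]

def classify_remote(endpoint_url, api_key, file_url, coordinates):
    """Send one request; endpoint_url and api_key are supplied by the caller."""
    body = json.dumps({"file_url": file_url, "coordinates": coordinates}).encode("utf-8")
    request = urllib.request.Request(
        endpoint_url,
        data=body,
        headers={"Authorization": api_key, "Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(request) as response:
        return extract_label(response.read().decode("utf-8"))

# Parse a response with the same shape as the output shown above.
sample_response = json.dumps({
    "run_id": "526ca24d-c3ef-41ce-a29c-6c03e7e9585c",
    "message": "success",
    "result": {"result": "academic gown, academic robe, judge's robe"},
    "run_time_ms": 13880.023717880249,
})
label = extract_label(sample_response)
```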