Segment Anything Example
In this tutorial, we will demonstrate how to get Meta's 'Segment Anything' model up and running in order to recognize items a person is wearing - think of the shopping feature on Instagram, which identifies clothes and links you to those items.
To see the final implementation, you can view it here.
Create Cerebrium Account
Before building, you need to set up a Cerebrium account. This is as simple as starting a new project in Cerebrium and copying the API key, which will be used to authenticate all calls for this project.
Create a project
- Go to dashboard.cerebrium.ai
- Sign up or Login
- Navigate to the API Keys page
- You should see an API key with the source “Cerebrium”. Click the eye icon to display it. It will be in the format: c_api_key-xxx
Basic Setup
It is important to note that developing models using Cerebrium should be identical to developing on a virtual machine or in Google Colab.
To start, we need to create a main.py file, which will contain our main Python code. This is a relatively simple implementation, so we can do everything in one file.
I then go straight to the Meta repo to see how to run this code. I need to install a few dependencies, so I create a requirements.txt file with the following dependencies:
git+https://github.com/facebookresearch/segment-anything.git
opencv-python
pycocotools
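Note that the code later in this tutorial also uses a few additional libraries (NumPy, PyTorch, Transformers, Pillow, Requests and Pydantic). Depending on what your Cerebrium environment already provides, you may need to add them as well - a possible full requirements.txt could look like this:
git+https://github.com/facebookresearch/segment-anything.git
opencv-python
pycocotools
numpy
torch
torchvision
transformers
Pillow
requests
pydantic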
I then define the following at the top of my main.py file:
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
import cv2
import numpy as np

sam = sam_model_registry["default"](checkpoint="./sam_vit_h_4b8939.pth")
sam.to("cuda")
mask_generator = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=32,
    pred_iou_thresh=0.96,
    stability_score_thresh=0.92,
    crop_n_layers=1,
    crop_n_points_downscale_factor=2,
    min_mask_region_area=100,  # Requires open-cv to run post-processing
)
We define this at the top of the file for two reasons:
- Code at the top of the file is run automatically when deployed, meaning the model can be downloaded and loaded into memory on the first deployment. The first deployment is the longest, but thereafter it should be pretty quick.
- If this were in a function, it would not be run on deploy but on the first request, and it would be loaded from cache on every request. We do not want either of these two things.
The SAM model lets you set parameters that control how precise or detailed you would like the segmentation to be. Since our use case is detecting clothes in an image, we don't need the segmentation to be that granular. You can look at Meta's documentation for the different parameters you can set.
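For reference, mask_generator.generate returns a list of annotations, one per detected segment. According to the Segment Anything repository, each annotation is a dictionary; the sketch below is illustrative and shows the fields we rely on later (the exact set of keys may differ slightly between versions):
# Illustrative sketch of one annotation returned by mask_generator.generate(image)
masks = mask_generator.generate(image)  # list of dicts, one per detected segment
ann = masks[0]
ann["segmentation"]     # boolean numpy array (HxW) marking the segment's pixels
ann["bbox"]             # bounding box in XYWH format
ann["area"]             # segment area in pixels
ann["predicted_iou"]    # the model's own quality estimate for the mask
ann["stability_score"]  # an additional mask-quality measure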
In the documentation, we need to select a model type and a checkpoint. Meta provides three - we will just use the default, so I download the ViT-H SAM model checkpoint. You can simply add this checkpoint file to the same directory as your main.py, and when you deploy your model, Cerebrium will automatically upload it.
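If you prefer to fetch the checkpoint programmatically instead of committing it next to main.py, a minimal sketch (using the ViT-H download URL listed in the Segment Anything repository at the time of writing) could look like this:
import os
import urllib.request

CHECKPOINT_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
CHECKPOINT_PATH = "./sam_vit_h_4b8939.pth"

# Download the ViT-H checkpoint once if it is not already present
if not os.path.exists(CHECKPOINT_PATH):
    urllib.request.urlretrieve(CHECKPOINT_URL, CHECKPOINT_PATH)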
I would then like the user to send in a few parameters for the model via the API, so I need to define what the request object will look like:
from pydantic import BaseModel
from typing import Optional

class Item(BaseModel):
    image: Optional[str]
    file_url: Optional[str]
    cursor: list
Pydantic is a data validation library, and BaseModel is where Cerebrium keeps some default parameters, such as "webhook_url", which allows a user to send in a webhook URL that we call when the job has finished processing - this is useful for long-running tasks. Do not worry about that functionality for this tutorial. The reason we accept both an image and a file_url is to give the user a choice: they can send in a base64-encoded image or a publicly accessible file_url that we can download.
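As an illustration of those two options, a client could build either request body like this (the local file name and URL below are hypothetical):
import base64

# Option 1: send a base64-encoded image
with open("outfit.jpg", "rb") as f:  # hypothetical local file
    payload = {
        "image": base64.b64encode(f.read()).decode("utf-8"),
        "cursor": [400, 300],
    }

# Option 2: send a publicly accessible URL instead
payload = {
    "file_url": "https://example.com/outfit.jpg",  # hypothetical URL
    "cursor": [400, 300],
}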
Identifying and classifying objects
Next we would like to create three functions with the following functionality:
- Based on the image coordinates sent in, we would like to find the annotation (segment) the user is most likely referring to.
- Take the annotation returned by Meta's model, create a new image from it, and crop it so it contains only the item we are focusing on. Cropping the image helps the classification model perform more accurately.
- Pass this new image into an object classification model so it can tell us what the item is.
# 1. Find the annotation (segment) that contains the given coordinates
def find_annotation_by_coordinates(annotations, x, y):
    for ann in annotations:
        bbox_x, bbox_y, bbox_w, bbox_h = ann['bbox']  # bbox is in XYWH format
        if bbox_x <= x <= bbox_x + bbox_w and bbox_y <= y <= bbox_y + bbox_h:
            return ann
    return None
# 2. Cut the selected segment out of the original image onto a white background
def create_image(ann, image):
    m = ann['segmentation']  # boolean mask for this segment
    resized_original_image = cv2.resize(image, (m.shape[1], m.shape[0]))
    mask = np.ones((m.shape[0], m.shape[1], 3), dtype=np.uint8) * 255  # white background
    mask[m] = resized_original_image[m]  # copy the original pixels inside the segment
    x, y, w, h = ann['bbox']
    cropped_image = mask[y:y+h, x:x+w]  # crop to the segment's bounding box
    return cv2.cvtColor(cropped_image, cv2.COLOR_RGB2BGR)
# 3. Classify the cropped image with a Vision Transformer
from transformers import ViTImageProcessor, ViTForImageClassification

processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')

def classify(image):
    inputs = processor(images=image, return_tensors="pt")
    outputs = model(**inputs)
    logits = outputs.logits
    # The model predicts one of the 1000 ImageNet classes
    predicted_class_idx = logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]
In the last function, we use a Vision Transformer model from Google that identifies objects and returns the most likely predicted class. You will notice we don't define the processor and model inside the function: we only want these models to be loaded once, so define them near the top of your file, underneath your SAM code.
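Before wiring these functions into the endpoint, you can sanity-check them locally with any test image (the file name below is hypothetical, and this assumes a GPU is available for SAM):
# Quick local test of the pipeline: segment -> crop -> classify
test_image = cv2.cvtColor(cv2.imread("test_outfit.jpg"), cv2.COLOR_BGR2RGB)  # hypothetical file
masks = mask_generator.generate(test_image)
ann = find_annotation_by_coordinates(masks, 400, 300)
if ann is not None:
    cropped = create_image(ann, test_image)
    print(classify(cropped))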
Putting it all together
We would then like to create a function that is called every time a user submits a request. This function will download or decode the image the user submitted, feed it into the model and return the item detected in the image. Our code looks as follows:
import requests
from PIL import Image
from io import BytesIO
import base64
def download_image(url):
    if not url:
        raise ValueError("No valid url passed")
    r = requests.get(url)
    return Image.open(BytesIO(r.content))
def predict(item, run_id, logger):
    if not item.image and not item.file_url:
        return {"message": "Please provide either an image or a file_url"}

    # Decode/download the image and convert it to an RGB numpy array for SAM
    if item.image:
        image = np.array(Image.open(BytesIO(base64.b64decode(item.image))).convert("RGB"))
    elif item.file_url:
        image = np.array(download_image(item.file_url).convert("RGB"))

    masks = mask_generator.generate(image)
    cursor_x, cursor_y = item.cursor
    selected_annotation = find_annotation_by_coordinates(masks, cursor_x, cursor_y)
    if not selected_annotation:
        return {"message": "No annotation found at the given coordinates."}

    segmented_image = create_image(selected_annotation, image)
    result = classify(segmented_image)
    return {"result": result}
In the above code we do a few things:
- We import the various packages we need and add the necessary ones to our requirements.txt
- We create a function that downloads the image if the user supplied a URL.
- We run our "Segment Anything" model on the image the user sent through the request.
- Lastly, we respond with the classification result. All returns from your endpoint function must be of type list or JSON.
We can then deploy our model to an A10 instance with the following line of code:
cerebrium-deploy -n segment-anything -h A10 -k <API_KEY>
After a few minutes, your model should be deployed and an endpoint returned. Let us create a cURL request to see the response:
curl --location --request POST 'https://run.cerebrium.ai/v1/p-xxxxxx/segment-anything/predict' \
--header 'Authorization: dev-c_api_key-XXXXXXXXXXXX' \
--header 'Content-Type: application/json' \
--data-raw '{
"file_url": "https://cdn-www.thefashionspot.com/assets/uploads/gallery/the-off-duty-male-models-of-milan-mens-fashion-week/milano-m-moc-rf14-8533.jpg",
"cursor": [400,300]
}'
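The same request can be made from Python, for example (substitute your own project id and API key):
import requests

url = "https://run.cerebrium.ai/v1/p-xxxxxx/segment-anything/predict"  # your endpoint from the deploy step
headers = {
    "Authorization": "dev-c_api_key-XXXXXXXXXXXX",  # your Cerebrium API key
    "Content-Type": "application/json",
}
payload = {
    "file_url": "https://cdn-www.thefashionspot.com/assets/uploads/gallery/the-off-duty-male-models-of-milan-mens-fashion-week/milano-m-moc-rf14-8533.jpg",
    "cursor": [400, 300],
}
response = requests.post(url, headers=headers, json=payload)
print(response.json())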
And here is the output of our model:
{
"run_id": "526ca24d-c3ef-41ce-a29c-6c03e7e9585c",
"message": "success",
"result": {
"result": "academic gown, academic robe, judge's robe"
},
"run_time_ms": 13880.023717880249
}