Introduction

In this tutorial, I will show you how you can migrate your workloads from Replicate to Cerebrium in less than 5 minutes!

As an example, we will migrate ByteDance's SDXL-Lightning-4step model. You can find it on Replicate here:

It is best to look at the code in the GitHub repo and follow along as we migrate it.

To start, let us create our Cerebrium project:

cerebrium init cog-migration-sdxl

Cerebrium and Replicate have a similar setup in that each uses a configuration file: cog.yaml for Replicate and cerebrium.toml for Cerebrium.

Looking at the cog.yaml, we need to add/change the following in our cerebrium.toml:

[cerebrium.deployment]
name = "cog-migration-sdxl"
python_version = "3.11"
include = "[./*, main.py, cerebrium.toml]"
exclude = "[./example_exclude]"
docker_base_image_url = "nvidia/cuda:12.1.1-cudnn8-runtime-ubuntu22.04"
shell_commands = [
    "curl -o /usr/local/bin/pget -L 'https://github.com/replicate/pget/releases/download/v0.6.2/pget_linux_x86_64' && chmod +x /usr/local/bin/pget"
]

[cerebrium.hardware]
region = "us-east-1"
provider = "aws"
compute = "AMPERE_A10"
cpu = 2
memory = 12.0
gpu_count = 1

[cerebrium.dependencies.pip]
"accelerate" = "latest"
"diffusers" = "latest"
"torch" = "==2.0.1"
"torchvision" = "==0.15.2"
"transformers" = "latest"

[cerebrium.dependencies.apt]
"curl" = "latest"

From the above we do the following:

  • Since we need a GPU, we use one of the Nvidia base images that come with the CUDA libraries installed - in this case the CUDA 12 image. You can see other images here
  • Depending on the type of CPU/GPU your application needs, you can update the hardware settings. You can see the full list of available options here
  • We copy across the pip packages we need to install
  • Replicate uses pget to download model weights, so we need it available in our environment. We do this by installing curl and then adding the shell commands in our cerebrium.toml to fetch the pget binary

Great, now our hardware and environment setup is the same.

The cog.yaml also indicates the file that the endpoint calls - in this case, predict.py - so let us inspect that file.

Cerebrium has a similar notion: the main file that is called on our side is main.py.
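To illustrate the convention, here is a minimal, hypothetical main.py (the function and parameter names below are just examples) - any top-level function you define becomes an endpoint you can call:

def hello(name: str = "world"):
    # Each top-level function becomes an endpoint; the JSON body of a
    # request maps onto the function's parameters.
    return {"message": f"Hello, {name}!"}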

To start, I copy across all import statements and constant variables that have nothing to do with Replicate/Cog. In this case:

import os
import time
import torch
import subprocess
import numpy as np
from pathlib import Path  # replaces the Path helper that cog provided
from typing import List
from transformers import CLIPImageProcessor
from diffusers import (
    StableDiffusionXLPipeline,
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    HeunDiscreteScheduler,
    PNDMScheduler,
    KDPM2AncestralDiscreteScheduler,
)
from diffusers.pipelines.stable_diffusion.safety_checker import (
    StableDiffusionSafetyChecker,
)

UNET = "sdxl_lightning_4step_unet.pth"
MODEL_BASE = "stabilityai/stable-diffusion-xl-base-1.0"
UNET_CACHE = "unet-cache"
BASE_CACHE = "checkpoints"
SAFETY_CACHE = "safety-cache"
FEATURE_EXTRACTOR = "feature-extractor"
MODEL_URL = "https://weights.replicate.delivery/default/sdxl-lightning/sdxl-1.0-base-lightning.tar"
SAFETY_URL = "https://weights.replicate.delivery/default/sdxl/safety-1.0.tar"
UNET_URL = "https://weights.replicate.delivery/default/comfy-ui/unet/sdxl_lightning_4step_unet.pth.tar"

class KarrasDPM:
    def from_config(config):
        return DPMSolverMultistepScheduler.from_config(config, use_karras_sigmas=True)


SCHEDULERS = {
    "DDIM": DDIMScheduler,
    "DPMSolverMultistep": DPMSolverMultistepScheduler,
    "HeunDiscrete": HeunDiscreteScheduler,
    "KarrasDPM": KarrasDPM,
    "K_EULER_ANCESTRAL": EulerAncestralDiscreteScheduler,
    "K_EULER": EulerDiscreteScheduler,
    "PNDM": PNDMScheduler,
    "DPM++2MSDE": KDPM2AncestralDiscreteScheduler,
}

Replicate makes use of classes for its syntax, which we shy away from - we run whatever Python code you give us and make each function an endpoint. Therefore, when you see a reference to self., remove it throughout the code.
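Schematically, the change looks like this (a simplified sketch, not the actual predict.py; load_pipeline is a stand-in for the real model-loading code):

def load_pipeline():
    # Stand-in for the real model loading done in setup()
    return lambda prompt: f"image for: {prompt}"

# Replicate (cog) style keeps state on a Predictor class:
#
#     class Predictor(BasePredictor):
#         def setup(self):
#             self.pipe = load_pipeline()
#
#         def predict(self, prompt: str) -> Path:
#             return self.pipe(prompt)

# Cerebrium style: module-level code runs once per cold start,
# and plain functions become endpoints.
pipe = load_pipeline()

def predict(prompt: str):
    return pipe(prompt)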

There is a folder in the repo called “feature-extractor” which we need to have in our repository. We could git clone the repo; however, it's quite small, so I would just copy the contents of the folder into your Cerebrium project, i.e.:

Folder Structure

The setup function on Replicate runs on each cold start (i.e. each new instantiation of the application), so we define it as normal code that runs at the top of our file. I put it right below the import statements above.

def download_weights(url, dest):
    start = time.time()
    print("downloading url: ", url)
    print("downloading to: ", dest)
    subprocess.check_call(["pget", "-x", url, dest], close_fds=False)
    print("downloading took: ", time.time() - start)

"""Load the model into memory to make running multiple predictions efficient"""
start = time.time()
print("Loading safety checker...")
if not os.path.exists(SAFETY_CACHE):
    download_weights(SAFETY_URL, SAFETY_CACHE)
print("Loading model")
if not os.path.exists(BASE_CACHE):
    download_weights(MODEL_URL, BASE_CACHE)
print("Loading Unet")
if not os.path.exists(UNET_CACHE):
    download_weights(UNET_URL, UNET_CACHE)
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
    SAFETY_CACHE, torch_dtype=torch.float16
).to("cuda")
self.feature_extractor = CLIPImageProcessor.from_pretrained(FEATURE_EXTRACTOR)
print("Loading txt2img pipeline...")
self.pipe = StableDiffusionXLPipeline.from_pretrained(
    MODEL_BASE,
    torch_dtype=torch.float16,
    variant="fp16",
    cache_dir=BASE_CACHE,
    local_files_only=True,
).to("cuda")
unet_path = os.path.join(UNET_CACHE, UNET)
self.pipe.unet.load_state_dict(torch.load(unet_path, map_location="cuda"))
print("setup took: ", time.time() - start)

The code above downloads the model weights if they don’t exist and then instantiates the models. To persist files/data on Cerebrium, you need to store them under the path /persistent-storage. So we can update the following paths above:

UNET_CACHE = "/persistent-storage/unet-cache"
BASE_CACHE = "/persistent-storage/checkpoints"
SAFETY_CACHE = "/persistent-storage/safety-cache"

We can then copy across the two other functions, run_safety_checker() and predict(). In Cerebrium, the parameters of a function define the JSON data it expects when you make a request to it. We can then define them as follows:

def run_safety_checker(image):
    safety_checker_input = feature_extractor(image, return_tensors="pt").to(
        "cuda"
    )
    np_image = [np.array(val) for val in image]
    image, has_nsfw_concept = safety_checker(
        images=np_image,
        clip_input=safety_checker_input.pixel_values.to(torch.float16),
    )
    return image, has_nsfw_concept

def predict(
    prompt: str = "A superhero smiling",
    negative_prompt: str = "worst quality, low quality",
    width: int = 1024,
    height: int = 1024,
    num_outputs: int = 1,
    scheduler: str = "K_EULER",
    num_inference_steps: int = 4,
    guidance_scale: float = 0,
    seed: int = None,
    disable_safety_checker: bool = False,
):
    """Run a single prediction on the model"""
    global pipe
    if seed is None:
        seed = int.from_bytes(os.urandom(4), "big")
    print(f"Using seed: {seed}")
    generator = torch.Generator("cuda").manual_seed(seed)

    # OOMs can leave vae in bad state
    if pipe.vae.dtype == torch.float32:
        pipe.vae.to(dtype=torch.float16)

    sdxl_kwargs = {}
    print(f"Prompt: {prompt}")
    sdxl_kwargs["width"] = width
    sdxl_kwargs["height"] = height

    pipe.scheduler = SCHEDULERS[scheduler].from_config(
        pipe.scheduler.config, timestep_spacing="trailing"
    )

    common_args = {
        "prompt": [prompt] * num_outputs,
        "negative_prompt": [negative_prompt] * num_outputs,
        "guidance_scale": guidance_scale,
        "generator": generator,
        "num_inference_steps": num_inference_steps,
    }

    output = pipe(**common_args, **sdxl_kwargs)

    if not disable_safety_checker:
        _, has_nsfw_content = run_safety_checker(output.images)

    output_paths = []
    for i, image in enumerate(output.images):
        if not disable_safety_checker:
            if has_nsfw_content[i]:
                print(f"NSFW content detected in image {i}")
                continue
        output_path = f"/tmp/out-{i}.png"
        image.save(output_path)
        output_paths.append(Path(output_path))

    if len(output_paths) == 0:
        raise Exception(
            "NSFW content detected. Try running it again, or try a different prompt."
        )

    return output_paths

The above returns paths to the generated images, but we would like to return them as base64-encoded images so that users can render them instantly. You are welcome to upload the images to a storage bucket and reference them directly instead - it's up to you. To make the change, add the imports below to the top of main.py and replace the output_paths block inside predict() with the following:

from io import BytesIO
import base64

    encoded_images = []
    for i, image in enumerate(output.images):
        if not disable_safety_checker:
            if has_nsfw_content[i]:
                print(f"NSFW content detected in image {i}")
                continue
        buffered = BytesIO()
        image.save(buffered, format="PNG")
        img_b64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
        encoded_images.append(img_b64)

    if len(encoded_images) == 0:
        raise Exception(
            "NSFW content detected. Try running it again, or try a different prompt."
        )

    return encoded_images
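On the client side, a returned base64 string can be written back out as an image file. Here is a minimal sketch (the result variable stands in for the "result" list returned by the endpoint):

import base64

# `result` stands in for the list of base64 strings returned by the endpoint
for i, img_b64 in enumerate(result):
    with open(f"out-{i}.png", "wb") as f:
        f.write(base64.b64decode(img_b64))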

Now we can run cerebrium deploy. You should see your application build in under 90 seconds.

It should output the curl statement to run your application:

Curl Request

Make sure to replace the end of the URL with /predict (since that is the function we are calling) and send it the required JSON data. This is our result:

{
    "run_id": "c6797f2e-333a-9e89-bafa-4dd0f4fbe22a",
    "result": ["iVBORw0KGgoAAAANSUhEUgAABAAAAAQACAIAAADwf7zUAA...."],
    "run_time_ms": 43623.4176158905
}
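For reference, you could also call the endpoint from Python. The URL and API key below are placeholders - use the values shown in your deploy output:

import requests

# Base URL comes from the `cerebrium deploy` output; append /predict for this function
url = "<YOUR_ENDPOINT_URL>/predict"
headers = {
    "Authorization": "Bearer <YOUR_API_KEY>",  # placeholder API key
    "Content-Type": "application/json",
}
payload = {
    "prompt": "A superhero smiling",
    "num_inference_steps": 4,
    "guidance_scale": 0,
}

response = requests.post(url, json=payload, headers=headers)
data = response.json()
print(data["run_id"], len(data["result"]))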

You should be all ready to go!

You can read further about some of the functionality Cerebrium has to offer.