This example requires CLI v1.20 or later. If you are using an older version of the CLI, run pip install --upgrade cerebrium to upgrade to the latest version.

This tutorial shows you how to generate high-quality images using the SDXL refiner model from Stability AI, available on Hugging Face.

To see the final implementation, you can view it here

Basic Setup

Developing models with Cerebrium is similar to developing on a virtual machine or Google Colab, making conversion straightforward. Make sure you have the Cerebrium package installed and are logged in. If not, check our docs here.

First, create your project:

cerebrium init 3-sdxl-refiner

Configure your compute and environment settings in cerebrium.toml:


[cerebrium.deployment]
name = "3-sdxl-refiner"
python_version = "3.10"
include = ["./*", "main.py", "cerebrium.toml"]
exclude = ["./.*", "./__*"]

[cerebrium.hardware]
region = "us-east-1"
provider = "aws"
compute = "AMPERE_A10"
cpu = 2
memory = 16.0
gpu_count = 1

[cerebrium.scaling]
min_replicas = 0
max_replicas = 5
cooldown = 60

[cerebrium.dependencies.pip]
accelerate = "latest"
transformers = ">=4.35.0"
safetensors = "latest"
opencv-python = "latest"
diffusers = "latest"

[cerebrium.dependencies.conda]

[cerebrium.dependencies.apt]
ffmpeg = "latest"

Create a main.py file for our Python code. This simple implementation can be done in a single file. First, let’s define our request object:

from typing import Optional
from pydantic import BaseModel
import torch
from diffusers import StableDiffusionXLImg2ImgPipeline
from diffusers.utils import load_image
import io
import base64

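class Item(BaseModel):
    # Required fields: a request without them fails validation
    prompt: str
    url: str
    # Optional fields: defaults mirror the predict function signature
    negative_prompt: Optional[str] = None
    conditioning_scale: float = 0.5
    height: int = 512
    width: int = 512
    num_inference_steps: int = 20
    guidance_scale: float = 7.5
    num_images_per_prompt: int = 1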

We import the required Python libraries and use Pydantic for data validation. The prompt and url parameters are required, while all others are optional. Missing required parameters will trigger an automatic error message.
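To make the required/optional distinction concrete, here is a minimal sketch of how Pydantic treats the two kinds of fields. MiniItem is a hypothetical, trimmed-down stand-in for the Item class above (not part of the tutorial's main.py), and the example URL is a placeholder.

```python
from typing import Optional
from pydantic import BaseModel, ValidationError


class MiniItem(BaseModel):
    # Hypothetical trimmed model, for illustration only
    prompt: str                              # required: no default
    url: str                                 # required: no default
    negative_prompt: Optional[str] = None    # optional: has a default
    num_inference_steps: int = 20            # optional: has a default


# Only the required fields are supplied; defaults fill in the rest
ok = MiniItem(prompt="an astronaut", url="https://example.com/input.png")
print(ok.num_inference_steps)

# Omitting a required field raises ValidationError, which names the field
missing = []
try:
    MiniItem(prompt="an astronaut")
except ValidationError as exc:
    missing = [err["loc"][0] for err in exc.errors()]
print(missing)
```

A field counts as optional only when it is given a default value, so each optional parameter should carry one.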

Instantiate Model

We load the SDXL refiner outside the predict function because it only needs to be loaded once, at startup. The model is downloaded during the initial deployment and then cached automatically in persistent storage for subsequent use.

pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
)
pipe = pipe.to("cuda")

Predict Function

The predict function takes parameters from the request and passes them to the SDXL model to generate images. We convert images to base64 for direct JSON-serializable responses instead of writing to S3.

def predict(prompt, url, negative_prompt=None, conditioning_scale=0.5, height=512, width=512, num_inference_steps=20,
            guidance_scale=7.5, num_images_per_prompt=1):
    item = Item(
        prompt=prompt,
        url=url,
        negative_prompt=negative_prompt,
        conditioning_scale=conditioning_scale,
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        num_images_per_prompt=num_images_per_prompt
    )

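    # Download the input image and normalize it to RGB
    init_image = load_image(item.url).convert("RGB")
    # Note: controlnet_conditioning_scale is a ControlNet pipeline argument; the
    # img2img refiner pipeline may ignore or reject it depending on the diffusers version
    images = pipe(
        item.prompt,
        negative_prompt=item.negative_prompt,
        controlnet_conditioning_scale=item.conditioning_scale,
        height=item.height,
        width=item.width,
        num_inference_steps=item.num_inference_steps,
        guidance_scale=item.guidance_scale,
        num_images_per_prompt=item.num_images_per_prompt,
        image=init_image
    ).images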

    # Encode each image as a base64 PNG string so the response is JSON-serializable
    finished_images = []
    for image in images:
        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        finished_images.append(base64.b64encode(buffered.getvalue()).decode("utf-8"))

    return {"images": finished_images}

Deploy

Deploy the model using this command:

cerebrium deploy

After deployment, make this request:

curl --location 'https://api.cortex.cerebrium.ai/v4/p-<YOUR PROJECT ID>/3-sdxl-refiner/predict' \
--header 'Content-Type: application/json' \
--header 'Authorization: Bearer <YOUR TOKEN HERE>' \
--data '{
    "url": "https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/aa_xl/000000009.png",
    "prompt": "a photo of an astronaut riding a horse on mars"
}'
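The same request can also be made from Python. The following is a minimal sketch using only the standard library; build_predict_request is a hypothetical helper, and the project ID and token values are placeholders to replace with your own.

```python
import json
import urllib.request


def build_predict_request(project_id: str, token: str, payload: dict) -> urllib.request.Request:
    """Build (but do not send) a POST request mirroring the curl command."""
    # project_id is the full ID including the "p-" prefix
    url = f"https://api.cortex.cerebrium.ai/v4/{project_id}/3-sdxl-refiner/predict"
    return urllib.request.Request(
        url,
        data=json.dumps(payload).encode("utf-8"),
        method="POST",
        headers={
            "Content-Type": "application/json",
            "Authorization": f"Bearer {token}",
        },
    )


request = build_predict_request(
    "p-example",       # placeholder project ID
    "example-token",   # placeholder token
    {
        "url": "https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/aa_xl/000000009.png",
        "prompt": "a photo of an astronaut riding a horse on mars",
    },
)

# To send the request against a real deployment:
# with urllib.request.urlopen(request) as response:
#     result = json.load(response)
```

Building the request separately from sending it makes it easy to inspect the URL, headers, and body before making a billable call.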

The endpoint returns results in this format:

{
    "run_id": "Gd2fLvweh1sHpdEQd4XnxYRvtGmghFxSg2rpbchK7wWAFeso9-sOVg==",
    "message": "Finished inference request with run_id: `Gd2fLvweh1sHpdEQd4XnxYRvtGmghFxSg2rpbchK7wWAFeso9-sOVg==`",
    "result": {
        "images": [
            <BASE64_ENCODED_STRING>
        ]
    },
    "status_code": 200,
    "run_time_ms": 4388.460874557495
}
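Because the images come back as base64 strings, the client has to decode them before use. Below is a minimal sketch, assuming the response has been parsed into a dict with the shape shown above; save_images and the image_N.png file naming are hypothetical, and the fake response stands in for a real one.

```python
import base64
import tempfile
from pathlib import Path


def save_images(response: dict, out_dir: str) -> list:
    """Decode each base64 string in the response and write it to a PNG file."""
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    paths = []
    for index, encoded in enumerate(response["result"]["images"]):
        path = out / f"image_{index}.png"
        path.write_bytes(base64.b64decode(encoded))
        paths.append(path)
    return paths


# Stand-in response: the payload is only the PNG file signature, not a real image
fake_png = b"\x89PNG\r\n\x1a\n"
fake_response = {"result": {"images": [base64.b64encode(fake_png).decode("utf-8")]}}

saved = save_images(fake_response, tempfile.mkdtemp())
print([p.name for p in saved])
```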

Example output: