In this tutorial, we’ll create and deploy a Gradio chat interface that connects to a Llama 8B language model using Cerebrium’s custom ASGI runtime. We’ll build a scalable architecture where the frontend runs on CPU instances while the model can be deployed separately on GPU instances for optimal resource utilization.

You can find the full codebase for deploying your Gradio frontend here.

Architecture Overview

Our application consists of two main components:

  1. A frontend interface running on CPU instances using FastAPI and Gradio
  2. A separate Llama model endpoint running on GPU instances (this falls outside the scope of this article, but don’t fret: you can find a comprehensive example of deploying Llama 8B with TensorRT here)

This separation allows us to:

  • Always keep the frontend available while minimizing costs (CPU-only)
  • Scale our GPU-intensive model independently based on demand
  • Optimize resource allocation for different components

Prerequisites

Before starting, you’ll need:

  • A Cerebrium account (sign up here)
  • The Cerebrium CLI installed: pip install --upgrade cerebrium
  • A Llama model endpoint or another OpenAI-compatible LLM API endpoint (you can sanity-check it with the snippet below)
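
If you already have your model endpoint, it is worth sanity-checking it before wiring up the interface. The snippet below is a minimal sketch that assumes an OpenAI-compatible /v1/chat/completions route (the same one the chat interface calls later); replace the placeholder with your endpoint URL.

# Minimal sanity check for your model endpoint (assumes an OpenAI-compatible
# /v1/chat/completions route). Replace the placeholder with your endpoint URL.
import httpx

LLAMA_ENDPOINT = "<YOUR_MODEL_API_ENDPOINT>"

response = httpx.post(
    f"{LLAMA_ENDPOINT}/v1/chat/completions",
    json={
        "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
        "messages": [{"role": "user", "content": "Say hello in one sentence."}],
        "stream": False,
    },
    timeout=30.0,
)
response.raise_for_status()
# OpenAI-style chat completions return the reply under choices[0].message.content
print(response.json()["choices"][0]["message"]["content"])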

Basic Setup

First, create a new directory for your project and initialize it:

cerebrium init 34-gradio-interface

Next, let us add the following configuration to our cerebrium.toml file:

[cerebrium.deployment]
name = "33-gradio-interface"
python_version = "3.12"
disable_auth = true
include = ['./*', 'main.py', 'cerebrium.toml']
exclude = ['.*']

[cerebrium.runtime.custom]
entrypoint = ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8080"]
port = 8080
healthcheck_endpoint = "/health"

[cerebrium.hardware]
cpu = 2
memory = 4.0
compute = "CPU"

[cerebrium.scaling]
min_replicas = 0
max_replicas = 2
cooldown = 30
replica_concurrency = 10

[cerebrium.dependencies.pip]
gradio = "latest"
fastapi = "latest"
requests = "latest"
httpx = "latest"
uvicorn = "latest"
starlette = "latest"

This configuration does a few things. Let’s run through them step by step:

  • First, it disables the default JWT auth that is automatically applied to all Cerebrium endpoints, so that your Gradio interface is publicly accessible.
  • Then, it sets the entrypoint for the ASGI server to run through uvicorn (the sketch after this list shows a rough local equivalent).
  • It sets the port that your application is served on to 8080.
  • It sets the health endpoint (which we use to check whether your application is reachable) to /health. This is served by our FastAPI application.
  • It sets the hardware configuration for the CPU instance that your application will run on.
  • It sets the scaling configuration for your application: min_replicas and max_replicas are the minimum and maximum number of replicas your application can have, cooldown is the number of seconds Cerebrium waits after the last request before scaling a replica down, and replica_concurrency is the number of requests a single replica can handle at a time. For now, we set it to 10; you can change it to whatever suits your needs.
  • Finally, it sets the dependencies that your application needs to run: Gradio, FastAPI, Requests, HTTPX, Uvicorn, and Starlette.
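
As a sanity check of the entrypoint configuration, the snippet below is a rough local equivalent of what the custom runtime runs for you. This is an assumption for local testing only; on Cerebrium, the entrypoint command in cerebrium.toml is what actually starts the server.

# Rough local equivalent of the custom runtime entrypoint, for local testing only;
# on Cerebrium, the entrypoint command in cerebrium.toml starts the server for you.
import uvicorn

if __name__ == "__main__":
    # "main:app" and port 8080 must match the [cerebrium.runtime.custom] section above
    uvicorn.run("main:app", host="0.0.0.0", port=8080)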

Now, let’s set up our main entrypoint file (main.py). To start, let’s create our FastAPI application:

# at the top of your main.py file
import os
import requests
from typing import Optional, List

import httpx
from fastapi import FastAPI, Request
from starlette.responses import Response as StarletteResponse

app = FastAPI()

# Get the Gradio app URL (when running on Cerebrium)
GRADIO_HOST = os.getenv("GRADIO_HOST", "127.0.0.1")
GRADIO_PORT = int(os.getenv("GRADIO_PORT", "7860"))
GRADIO_URL = os.getenv("GRADIO_SERVER_URL", f"http://{GRADIO_HOST}:{GRADIO_PORT}")

# Health check endpoint
@app.get("/health")
async def health_check():
    return {"status": "healthy"}


@app.route("/{path:path}", include_in_schema=False, methods=["GET", "POST"])
async def gradio(request: Request):
    print(f"Forwarding request path: {request.url.path}")

    headers = dict(request.headers)

    # Construct the full URL to Gradio, preserving the original path
    target_url = f"{GRADIO_URL}{request.url.path}"

    async with httpx.AsyncClient() as client:
        response = await client.request(
            request.method,
            target_url,
            headers=headers,
            content=await request.body(),
            params=request.query_params,
        )

        content = await response.aread()
        response_headers = dict(response.headers)
        return StarletteResponse(
            content=content,
            status_code=response.status_code,
            headers=response_headers,
        )

The above code does the following:

  • It initializes a FastAPI application, which we’ll use to forward requests to our Gradio application running on the same instance (as a subprocess) but on a different port
  • It sets up a health check endpoint at /health
  • It sets up a catch-all proxy that routes every other request to Gradio (notice how we forward the headers through to Gradio too); you can smoke-test this locally with the sketch below
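
Before deploying, you can check the proxy locally. This is a hypothetical smoke test that assumes you start the app yourself on port 8080 (for example with uvicorn main:app --port 8080) and that the Gradio subprocess comes up on port 7860.

# Hypothetical local smoke test; assumes the app is running on port 8080
# and the Gradio subprocess has started on port 7860.
import httpx

# Answered directly by FastAPI, never proxied
print(httpx.get("http://localhost:8080/health").json())

# Any other path is forwarded to Gradio; a 200 means the proxy is working
print(httpx.get("http://localhost:8080/").status_code)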

Now that our FastAPI application is set up, let’s set up our Gradio application. Staying in main.py, add the following code:

# at the top of our main.py file
import multiprocessing
import sys
import time

import gradio as gr

# Global variable for the Gradio server
gradio_server = None

# Configure the Llama endpoint URL
LLAMA_ENDPOINT = os.getenv("LLAMA_ENDPOINT", "<YOUR_MODEL_API_ENDPOINT>")  # Update with your endpoint


class GradioServer:
    def __init__(self):
        self.host = GRADIO_HOST
        self.port = GRADIO_PORT
        self.process: Optional[multiprocessing.Process] = None
        self.url = GRADIO_URL

    async def chat_with_llama(self, message: str, history: List[dict]) -> str:
        """Make a request to the Llama endpoint"""
        # With type="messages", Gradio passes history as OpenAI-style
        # {"role": ..., "content": ...} dicts, so we can forward it as-is
        messages = [{"role": h["role"], "content": h["content"]} for h in history]
        messages.append({"role": "user", "content": message})

        async with httpx.AsyncClient() as client:
            try:
                response = await client.post(
                    f"{LLAMA_ENDPOINT}/v1/chat/completions",
                    json={
                        "messages": messages,
                        "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
                        "stream": False,
                        "temperature": 0.7,
                        "top_p": 0.95
                    },
                    timeout=30.0
                )

                if response.status_code == 200:
                    response_data = response.json()
                    # OpenAI-style chat completions return the reply under message.content
                    return response_data['choices'][0]['message']['content']
                else:
                    return f"Error: Received status code {response.status_code} from Llama endpoint"
            except Exception as e:
                return f"Error communicating with Llama endpoint: {str(e)}"

    def run_server(self):
        interface = gr.ChatInterface(
            fn=self.chat_with_llama,
            type="messages",
            title="Chat with Llama",
            description="This is a chat interface powered by Llama 3.1 8B Instruct",
            examples=[
                ["What is the capital of France?"],
                ["Explain quantum computing in simple terms"],
                ["Write a short poem about technology"]
            ],
        )
        interface.launch(
            server_name=self.host,
            server_port=self.port,
            root_path=f"https://api.cortex.cerebrium.ai/v4/{os.getenv('PROJECT_ID')}/{os.getenv('APP_NAME')}/",
            quiet=True
        )

    def start(self):
        print(f"Starting Gradio server at {self.url} port {self.port}")

        # Start Gradio in a separate process
        self.process = multiprocessing.Process(target=self.run_server)
        self.process.start()

        # Wait for Gradio to become ready
        max_retries = 30
        retry_delay = 1.0

        for _ in range(max_retries):
            try:
                response = requests.get(f"{self.url}/")
                if response.status_code == 200:
                    print(f"Gradio server is ready at {self.url}")
                    return True
            except requests.exceptions.ConnectionError:
                time.sleep(retry_delay)

        print("Failed to start Gradio server")
        self.stop()
        return False

    def stop(self):
        if self.process:
            self.process.terminate()
            self.process.join()
            self.process = None

@app.on_event("startup")
async def startup_event():
    global gradio_server
    if not os.getenv("GRADIO_SERVER_URL"):  # Only start local server if no external URL provided
        gradio_server = GradioServer()
        if not gradio_server.start():
            sys.exit(1)


@app.on_event("shutdown")
async def shutdown_event():
    global gradio_server
    if gradio_server:
        gradio_server.stop()

Above, we have:

  • A GradioServer class that wraps the Gradio server process and the communication with the Llama model endpoint
  • A chat_with_llama method that sends the conversation to the Llama model and returns its response
  • A run_server method that creates the Gradio chat interface
  • A start method that starts the Gradio server in a separate process and waits for it to become reachable
  • A stop method that stops the Gradio server
  • Startup and shutdown event handlers that start and stop the Gradio server respectively (if you prefer the newer lifespan style, see the optional sketch below)
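
As an aside, on_event handlers are deprecated in recent FastAPI releases in favor of a lifespan handler. If you prefer that style, here is a minimal sketch of the same startup/shutdown logic; it reuses GradioServer, os, and sys from main.py and is not part of the tutorial code below.

# Optional lifespan-based equivalent of the on_event handlers (a sketch only;
# it reuses GradioServer, os, and sys defined elsewhere in main.py).
from contextlib import asynccontextmanager

from fastapi import FastAPI


@asynccontextmanager
async def lifespan(app: FastAPI):
    gradio_server = None
    # Only start a local Gradio server if no external URL is provided
    if not os.getenv("GRADIO_SERVER_URL"):
        gradio_server = GradioServer()
        if not gradio_server.start():
            sys.exit(1)
    yield
    if gradio_server:
        gradio_server.stop()


app = FastAPI(lifespan=lifespan)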

Finally, your main.py file should look like this:

import multiprocessing
import os
import sys
import time
from typing import Optional, List

import gradio as gr
import httpx
import requests
from fastapi import FastAPI, Request
from starlette.responses import Response as StarletteResponse

# Initialize FastAPI
app = FastAPI()

# Global variable for the Gradio server
gradio_server = None

# Get the Gradio app URL (when running on Cerebrium)
GRADIO_HOST = os.getenv("GRADIO_HOST", "127.0.0.1")
GRADIO_PORT = int(os.getenv("GRADIO_PORT", "7860"))
GRADIO_URL = os.getenv("GRADIO_SERVER_URL", f"http://{GRADIO_HOST}:{GRADIO_PORT}")

# Configure the Llama endpoint URL
LLAMA_ENDPOINT = os.getenv("LLAMA_ENDPOINT", "<YOUR_MODEL_API_ENDPOINT>")  # Update with your endpoint


class GradioServer:
    def __init__(self):
        self.host = GRADIO_HOST
        self.port = GRADIO_PORT
        self.process: Optional[multiprocessing.Process] = None
        self.url = GRADIO_URL

    async def chat_with_llama(self, message: str, history: List[dict]) -> str:
        """Make a request to the Llama endpoint"""
        # With type="messages", Gradio passes history as OpenAI-style
        # {"role": ..., "content": ...} dicts, so we can forward it as-is
        messages = [{"role": h["role"], "content": h["content"]} for h in history]
        messages.append({"role": "user", "content": message})

        async with httpx.AsyncClient() as client:
            try:
                response = await client.post(
                    f"{LLAMA_ENDPOINT}/v1/chat/completions",
                    json={
                        "messages": messages,
                        "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
                        "stream": False,
                        "temperature": 0.7,
                        "top_p": 0.95
                    },
                    timeout=30.0
                )

                if response.status_code == 200:
                    response_data = response.json()
                    # OpenAI-style chat completions return the reply under message.content
                    return response_data['choices'][0]['message']['content']
                else:
                    return f"Error: Received status code {response.status_code} from Llama endpoint"
            except Exception as e:
                return f"Error communicating with Llama endpoint: {str(e)}"

    def run_server(self):
        interface = gr.ChatInterface(
            fn=self.chat_with_llama,
            type="messages",
            title="Chat with Llama",
            description="This is a chat interface powered by Llama 3.1 8B Instruct",
            examples=[
                ["What is the capital of France?"],
                ["Explain quantum computing in simple terms"],
                ["Write a short poem about technology"]
            ],
        )
        interface.launch(
            server_name=self.host,
            server_port=self.port,
            root_path=f"https://api.cortex.cerebrium.ai/v4/{os.getenv('PROJECT_ID')}/{os.getenv('APP_NAME')}/",
            quiet=True
        )

    def start(self):
        print(f"Starting Gradio server at {self.url} port {self.port}")

        # Start Gradio in a separate process
        self.process = multiprocessing.Process(target=self.run_server)
        self.process.start()

        # Wait for Gradio to become ready
        max_retries = 30
        retry_delay = 1.0

        for _ in range(max_retries):
            try:
                response = requests.get(f"{self.url}/")
                if response.status_code == 200:
                    print(f"Gradio server is ready at {self.url}")
                    return True
            except requests.exceptions.ConnectionError:
                time.sleep(retry_delay)

        print("Failed to start Gradio server")
        self.stop()
        return False

    def stop(self):
        if self.process:
            self.process.terminate()
            self.process.join()
            self.process = None


# Health check endpoint
@app.get("/health")
async def health_check():
    return {"status": "healthy"}


# Catchall proxy endpoint for Gradio
@app.route("/{path:path}", include_in_schema=False, methods=["GET", "POST"])
async def gradio(request: Request):
    print(f"Forwarding request path: {request.url.path}")

    headers = dict(request.headers)

    # Construct the full URL to Gradio, preserving the original path
    target_url = f"{GRADIO_URL}{request.url.path}"

    async with httpx.AsyncClient() as client:
        response = await client.request(
            request.method,
            target_url,
            headers=headers,
            content=await request.body(),
            params=request.query_params,
        )

        content = await response.aread()
        response_headers = dict(response.headers)
        return StarletteResponse(
            content=content,
            status_code=response.status_code,
            headers=response_headers,
        )


@app.on_event("startup")
async def startup_event():
    global gradio_server
    if not os.getenv("GRADIO_SERVER_URL"):  # Only start local server if no external URL provided
        gradio_server = GradioServer()
        if not gradio_server.start():
            sys.exit(1)


@app.on_event("shutdown")
async def shutdown_event():
    global gradio_server
    if gradio_server:
        gradio_server.stop()

Deploy

Deploy the app using the following command:

cerebrium deploy -y

Once deployed, navigate to the following URL in your browser:

https://api.cortex.cerebrium.ai/v4/p-<YOUR PROJECT ID>/34-gradio-interface/

You should see the Gradio chat interface.
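
You can also confirm the deployment programmatically. The small sketch below is an assumption based on the URL pattern above; substitute your own project ID.

# Hypothetical post-deploy check; substitute your own project ID.
import httpx

BASE_URL = "https://api.cortex.cerebrium.ai/v4/p-<YOUR PROJECT ID>/34-gradio-interface"

print(httpx.get(f"{BASE_URL}/health").json())  # expect {"status": "healthy"}
print(httpx.get(f"{BASE_URL}/").status_code)   # expect 200 once the Gradio UI is up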

Conclusion

This architecture provides a scalable chat application that efficiently utilizes our new ASGI custom runtime. The separation of frontend and backend services allows for improved performance and cost management while maintaining flexibility for future scaling. We hope that you’ve enjoyed this article. If you have, please feel free to share any feedback, challenges or any of your own Gradio applications in our Discord community.