In this tutorial, we’ll create and deploy a Gradio chat interface that connects to a Llama 8B language model using Cerebrium’s custom ASGI runtime. We’ll build a scalable architecture where the frontend runs on CPU instances, while the model runs separately on GPU instances for optimal resource utilization.
You can find the full codebase for deploying your Gradio frontend here.
Architecture Overview
Our application consists of two main components:
- A frontend interface running on CPU instances using FastAPI and Gradio.
- A separate Llama model endpoint running on GPU instances. (While this is beyond the scope of this article, you can find a comprehensive example for deploying Llama 8B with TensorRT here.)
This separation allows us to:
- Keep the frontend always available while minimizing costs (CPU-only).
- Scale our GPU-intensive model independently based on demand.
- Optimize resource allocation for different components.
Prerequisites
Before starting, you’ll need:
- A Cerebrium account (sign up here).
- The Cerebrium CLI installed (pip install --upgrade cerebrium).
- A Llama model endpoint (or other LLM API endpoint).
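If you already have an endpoint, it can help to confirm it responds before building the frontend. The snippet below is a minimal smoke test, assuming an OpenAI-compatible /v1/chat/completions route (the same format used later in this tutorial); the endpoint placeholder and model name should be replaced with your own values:
import httpx

# Replace with your own Llama deployment's base URL
LLAMA_ENDPOINT = "<YOUR_MODEL_API_ENDPOINT>"

response = httpx.post(
    f"{LLAMA_ENDPOINT}/v1/chat/completions",
    json={
        "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
        "messages": [{"role": "user", "content": "Say hello in one sentence."}],
        "stream": False,
    },
    timeout=30.0,
)
print(response.status_code)
print(response.json())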
Basic Setup
First, create a new directory for your project and initialize it:
cerebrium init 2-gradio-interface
Next, let’s add the following configuration to our cerebrium.toml file:
[cerebrium.deployment]
name = "2-gradio-interface"
python_version = "3.12"
disable_auth = true
include = ['./*', 'main.py', 'cerebrium.toml']
exclude = ['.*']
[cerebrium.runtime.custom]
entrypoint = ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8080"]
port = 8080
healthcheck_endpoint = "/health"
[cerebrium.hardware]
cpu = 2
memory = 4.0
compute = "CPU"
[cerebrium.scaling]
min_replicas = 0
max_replicas = 2
cooldown = 30
replica_concurrency = 10
[cerebrium.dependencies.pip]
gradio = "latest"
fastapi = "latest"
requests = "latest"
httpx = "latest"
uvicorn = "latest"
starlette = "latest"
This configuration does several things:
- Disables the default JWT authentication that is automatically placed on all Cerebrium endpoints, making your Gradio interface publicly accessible.
- Sets the entrypoint for the ASGI server to run through Uvicorn.
- Sets the default port to 8080 for serving your app.
- Sets the health endpoint to /health for checking app availability through our FastAPI application.
- Configures hardware settings for the CPU instance running your app.
- Defines scaling configuration with minimum and maximum replicas, cooldown period, and replica concurrency (set to 10 requests per replica).
- Specifies required dependencies: Gradio, FastAPI, Requests, HTTPX, Uvicorn, and Starlette.
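Before deploying, you can exercise the same entrypoint on your own machine. The sketch below is a hypothetical run_local.py helper that mirrors the Uvicorn command in the entrypoint above; it assumes the dependencies listed in cerebrium.toml are installed locally and that the main.py built in the next section is in place:
# run_local.py (hypothetical helper) — mirrors the entrypoint in cerebrium.toml
import uvicorn

if __name__ == "__main__":
    uvicorn.run("main:app", host="0.0.0.0", port=8080)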
Now, let’s set up our main entrypoint file (main.py). To start, let’s create our FastAPI application:
# at the top of your main.py file
import os
from typing import Optional, List

import httpx
import requests
from fastapi import FastAPI, Request
from starlette.responses import Response as StarletteResponse
app = FastAPI()
# Get the Gradio app URL (when running on Cerebrium)
GRADIO_HOST = os.getenv("GRADIO_HOST", "127.0.0.1")
GRADIO_PORT = int(os.getenv("GRADIO_PORT", "7860"))
GRADIO_URL = os.getenv("GRADIO_SERVER_URL", f"http://{GRADIO_HOST}:{GRADIO_PORT}")
# Health check endpoint
@app.get("/health")
async def health_check():
    return {"status": "healthy"}


@app.route("/{path:path}", include_in_schema=False, methods=["GET", "POST"])
async def gradio(request: Request):
    print(f"Forwarding request path: {request.url.path}")
    headers = dict(request.headers)

    # Construct the full URL to Gradio, preserving the original path
    target_url = f"{GRADIO_URL}{request.url.path}"

    async with httpx.AsyncClient() as client:
        response = await client.request(
            request.method,
            target_url,
            headers=headers,
            content=await request.body(),
            params=request.query_params,
        )
        content = await response.aread()
        response_headers = dict(response.headers)

        return StarletteResponse(
            content=content,
            status_code=response.status_code,
            headers=response_headers,
        )
The above code:
- Initializes a FastAPI application to forward requests to our Gradio app running as a subprocess on a different port.
- Sets up a health check endpoint at /health.
- Creates a catchall proxy that routes all requests to Gradio, including headers.
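To sanity-check this locally, you could run the app and hit both routes. The snippet below is a sketch that assumes the app is listening on http://localhost:8080; the root path will only return the Gradio page once the Gradio subprocess from the next section is running:
import httpx

base = "http://localhost:8080"  # assumed local address

# Answered directly by FastAPI
print(httpx.get(f"{base}/health").json())  # expected: {"status": "healthy"}

# Proxied through to the Gradio server (works once the subprocess is running)
print(httpx.get(f"{base}/").status_code)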
Now that we’ve set up our FastAPI application, let’s set up our Gradio application. Staying in our main.py, let’s add the following code:
import multiprocessing
import os
import sys
import time
import gradio as gr
# Global variable for the Gradio server
gradio_server = None
# Configure the Llama endpoint URL
LLAMA_ENDPOINT = os.getenv("LLAMA_ENDPOINT", "<YOUR_MODEL_API_ENDPOINT>") # Update with your endpoint
class GradioServer:
    def __init__(self):
        self.host = GRADIO_HOST
        self.port = GRADIO_PORT
        self.process: Optional[multiprocessing.Process] = None
        self.url = GRADIO_URL

    async def chat_with_llama(self, message: str, history: List[dict]) -> str:
        """Make a request to the Llama endpoint"""
        # With type="messages", Gradio supplies the history as OpenAI-style
        # {"role": ..., "content": ...} dicts, so we can forward them directly
        messages = [{"role": m["role"], "content": m["content"]} for m in history]
        messages.append({"role": "user", "content": message})

        async with httpx.AsyncClient() as client:
            try:
                response = await client.post(
                    f"{LLAMA_ENDPOINT}/v1/chat/completions",
                    json={
                        "messages": messages,
                        "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
                        "stream": False,
                        "temperature": 0.7,
                        "top_p": 0.95
                    },
                    timeout=30.0
                )
                if response.status_code == 200:
                    response_data = response.json()
                    # OpenAI-compatible chat completions return the reply here
                    return response_data["choices"][0]["message"]["content"]
                else:
                    return f"Error: Received status code {response.status_code} from Llama endpoint"
            except Exception as e:
                return f"Error communicating with Llama endpoint: {str(e)}"
    def run_server(self):
        interface = gr.ChatInterface(
            fn=self.chat_with_llama,
            type="messages",
            title="Chat with Llama",
            description="This is a chat interface powered by Llama 3.1 8B Instruct",
            examples=[
                ["What is the capital of France?"],
                ["Explain quantum computing in simple terms"],
                ["Write a short poem about technology"]
            ],
        )

        interface.launch(
            server_name=self.host,
            server_port=self.port,
            root_path=f"https://api.cortex.cerebrium.ai/v4/{os.getenv('PROJECT_ID')}/{os.getenv('APP_NAME')}/",
            quiet=True
        )

    def start(self):
        print(f"Starting Gradio server at {self.url} port {self.port}")

        # Start Gradio in a separate process
        self.process = multiprocessing.Process(target=self.run_server)
        self.process.start()

        # Wait for Gradio to become ready
        max_retries = 30
        retry_delay = 1.0
        for _ in range(max_retries):
            try:
                response = requests.get(f"{self.url}/")
                if response.status_code == 200:
                    print(f"Gradio server is ready at {self.url}")
                    return True
            except requests.exceptions.ConnectionError:
                time.sleep(retry_delay)

        print("Failed to start Gradio server")
        self.stop()
        return False

    def stop(self):
        if self.process:
            self.process.terminate()
            self.process.join()
            self.process = None


@app.on_event("startup")
async def startup_event():
    global gradio_server
    if not os.getenv("GRADIO_SERVER_URL"):  # Only start a local server if no external URL is provided
        gradio_server = GradioServer()
        if not gradio_server.start():
            sys.exit(1)


@app.on_event("shutdown")
async def shutdown_event():
    global gradio_server
    if gradio_server:
        gradio_server.stop()
Above, we have:
- A GradioServer class that runs the Gradio app and handles communication with the Llama model endpoint.
- A chat_with_llama method that sends the conversation to the Llama model and returns its response.
- A run_server method that creates and launches the Gradio chat interface.
- A start method that starts the Gradio server in a separate process and waits for it to become ready.
- A stop method that stops the Gradio server.
- Startup and shutdown event handlers that start and stop the Gradio server, respectively.
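For reference, the request body that chat_with_llama posts to the /v1/chat/completions route looks roughly like this for a short conversation (the message contents are illustrative):
# Illustrative payload posted to {LLAMA_ENDPOINT}/v1/chat/completions
payload = {
    "messages": [
        {"role": "user", "content": "What is the capital of France?"},
        {"role": "assistant", "content": "The capital of France is Paris."},
        {"role": "user", "content": "And what is its population?"},
    ],
    "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "stream": False,
    "temperature": 0.7,
    "top_p": 0.95,
}
An OpenAI-compatible server typically returns the reply under choices[0].message.content, which is what chat_with_llama reads from the response.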
Finally, your main.py file should look like this:
import multiprocessing
import os
import sys
import time
from typing import Optional, List
import gradio as gr
import httpx
import requests
from fastapi import FastAPI, Request
from starlette.responses import Response as StarletteResponse
# Initialize FastAPI
app = FastAPI()
# Server configuration
gradio_server = None
GRADIO_HOST = os.getenv("GRADIO_HOST", "127.0.0.1")
GRADIO_PORT = int(os.getenv("GRADIO_PORT", "7860"))
GRADIO_URL = os.getenv("GRADIO_SERVER_URL", f"http://{GRADIO_HOST}:{GRADIO_PORT}")
LLAMA_ENDPOINT = os.getenv("LLAMA_ENDPOINT", "<YOUR_MODEL_API_ENDPOINT>")
class GradioServer:
    def __init__(self):
        self.host = GRADIO_HOST
        self.port = GRADIO_PORT
        self.process: Optional[multiprocessing.Process] = None
        self.url = GRADIO_URL

    async def chat_with_llama(self, message: str, history: List[dict]) -> str:
        """Make a request to the Llama endpoint"""
        # With type="messages", Gradio supplies the history as OpenAI-style
        # {"role": ..., "content": ...} dicts, so we can forward them directly
        messages = [{"role": m["role"], "content": m["content"]} for m in history]
        messages.append({"role": "user", "content": message})

        async with httpx.AsyncClient() as client:
            try:
                response = await client.post(
                    f"{LLAMA_ENDPOINT}/v1/chat/completions",
                    json={
                        "messages": messages,
                        "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
                        "stream": False,
                        "temperature": 0.7,
                        "top_p": 0.95
                    },
                    timeout=30.0
                )
                if response.status_code == 200:
                    response_data = response.json()
                    # OpenAI-compatible chat completions return the reply here
                    return response_data["choices"][0]["message"]["content"]
                else:
                    return f"Error: Received status code {response.status_code} from Llama endpoint"
            except Exception as e:
                return f"Error communicating with Llama endpoint: {str(e)}"
    def run_server(self):
        interface = gr.ChatInterface(
            fn=self.chat_with_llama,
            type="messages",
            title="Chat with Llama",
            description="This is a chat interface powered by Llama 3.1 8B Instruct",
            examples=[
                ["What is the capital of France?"],
                ["Explain quantum computing in simple terms"],
                ["Write a short poem about technology"]
            ],
        )

        interface.launch(
            server_name=self.host,
            server_port=self.port,
            root_path=f"https://api.cortex.cerebrium.ai/v4/{os.getenv('PROJECT_ID')}/{os.getenv('APP_NAME')}/",
            quiet=True
        )

    def start(self):
        print(f"Starting Gradio server at {self.url} port {self.port}")

        # Start Gradio in a separate process
        self.process = multiprocessing.Process(target=self.run_server)
        self.process.start()

        # Wait for Gradio to become ready
        max_retries = 30
        retry_delay = 1.0
        for _ in range(max_retries):
            try:
                response = requests.get(f"{self.url}/")
                if response.status_code == 200:
                    print(f"Gradio server is ready at {self.url}")
                    return True
            except requests.exceptions.ConnectionError:
                time.sleep(retry_delay)

        print("Failed to start Gradio server")
        self.stop()
        return False

    def stop(self):
        if self.process:
            self.process.terminate()
            self.process.join()
            self.process = None


@app.get("/health")
async def health_check():
    return {"status": "healthy"}


# Catchall proxy endpoint for Gradio
@app.route("/{path:path}", include_in_schema=False, methods=["GET", "POST"])
async def gradio(request: Request):
    print(f"Forwarding request path: {request.url.path}")
    headers = dict(request.headers)

    # Construct the full URL to Gradio, preserving the original path
    target_url = f"{GRADIO_URL}{request.url.path}"

    async with httpx.AsyncClient() as client:
        response = await client.request(
            request.method,
            target_url,
            headers=headers,
            content=await request.body(),
            params=request.query_params,
        )
        content = await response.aread()
        response_headers = dict(response.headers)

        return StarletteResponse(
            content=content,
            status_code=response.status_code,
            headers=response_headers,
        )


@app.on_event("startup")
async def startup_event():
    global gradio_server
    if not os.getenv("GRADIO_SERVER_URL"):  # Only start a local server if no external URL is provided
        gradio_server = GradioServer()
        if not gradio_server.start():
            sys.exit(1)


@app.on_event("shutdown")
async def shutdown_event():
    global gradio_server
    if gradio_server:
        gradio_server.stop()
Deploy
Deploy the app using the following command:
cerebrium deploy
Once deployed, navigate to the following URL in your browser:
https://api.cortex.cerebrium.ai/v4/p-<YOUR_PROJECT_ID>/2-gradio-interface/
You should see the Gradio chat interface.
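You can also verify the deployment programmatically by calling the health endpoint we defined earlier (a quick sketch; since disable_auth is set, no authentication header is needed):
import httpx

url = "https://api.cortex.cerebrium.ai/v4/p-<YOUR_PROJECT_ID>/2-gradio-interface/health"
print(httpx.get(url, timeout=30.0).json())  # expected: {"status": "healthy"}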
Conclusion
This architecture provides a scalable chat app that efficiently utilizes our new ASGI custom runtime. The separation of frontend and backend services allows for improved performance and cost management while maintaining flexibility for future scaling. We hope you’ve enjoyed this tutorial. Please feel free to share feedback, challenges, or your own Gradio apps in our Discord community.