This example requires Cerebrium CLI v1.20 or later. If you are using an older version of the CLI, run pip install --upgrade cerebrium to update to the latest release.
In this tutorial, we’ll show you how to deploy Mistral 7B using the popular vLLM inference framework.
You can view the final implementation here
Basic Setup
Developing models with Cerebrium is similar to developing on a virtual machine or Google Colab, making conversion straightforward. Make sure you have the Cerebrium package installed and are logged in. If not, check our docs here.
First, create your project:
cerebrium init 1-faster-inference-with-vllm
Add these Python packages to the [cerebrium.dependencies.pip] section in your cerebrium.toml file:
[cerebrium.dependencies.pip]
sentencepiece = "latest"
torch = ">=2.0.0"
vllm = "latest"
transformers = ">=4.35.0"
accelerate = "latest"
xformers = "latest"
Create a main.py file for our Python code. This simple implementation can be done in a single file. First, let's define our request object:
from typing import Optional

from pydantic import BaseModel

class Item(BaseModel):
    prompt: str
    temperature: Optional[float] = 0.8
    top_p: Optional[float] = 0.75
    top_k: Optional[int] = 40
    max_tokens: Optional[int] = 256
    frequency_penalty: Optional[float] = 1.0
We use Pydantic for data validation. The prompt parameter is required, while the others are optional with default values. If prompt is missing from the request, the user receives an automatic error message.
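For instance, a request that omits prompt is rejected before any model code runs. Here is a minimal sketch of that behaviour using the Item class above (the example values are illustrative):

from pydantic import ValidationError

# Valid: prompt is supplied, the remaining fields fall back to their defaults
item = Item(prompt="What is the capital city of France?", temperature=0.7)
print(item.max_tokens)  # 256

# Invalid: prompt is missing, so Pydantic raises a ValidationError
try:
    Item(temperature=0.7)
except ValidationError as err:
    print(err)  # reports that the "prompt" field is required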
vLLM Implementation
Model Setup
import os
from vllm import LLM, SamplingParams
from huggingface_hub import login
# Your huggingface token (HF_AUTH_TOKEN) should be stored in your project secrets on your dashboard
login(token=os.environ.get("HF_AUTH_TOKEN"))
# Initialize model with optimized settings
llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.1", dtype="bfloat16", max_model_len=20000, gpu_memory_utilization=0.9)
def predict(prompt, temperature=0.8, top_p=0.75, top_k=40, max_tokens=256, frequency_penalty=1.0):
    item = Item(
        prompt=prompt,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        max_tokens=max_tokens,
        frequency_penalty=frequency_penalty,
    )

    sampling_params = SamplingParams(
        temperature=item.temperature,
        top_p=item.top_p,
        top_k=item.top_k,
        max_tokens=item.max_tokens,
        frequency_penalty=item.frequency_penalty,
    )

    outputs = llm.generate([item.prompt], sampling_params)

    generated_text = []
    for output in outputs:
        generated_text.append(output.outputs[0].text)

    return {"result": generated_text}
We load the model outside the predict function since it only needs to be loaded once at startup, not with every request. The predict function simply passes the input parameters from the request to the model and returns the generated outputs.
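If you want to smoke-test the function locally before deploying, a minimal sketch is shown below; it assumes your machine has a GPU with enough memory to load Mistral 7B and that HF_AUTH_TOKEN is set in your environment:

# Hypothetical local smoke test; not required for deployment
if __name__ == "__main__":
    response = predict("What is the capital city of France?", temperature=0.7, max_tokens=64)
    print(response["result"][0])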
Deploy
Configure your compute and environment settings in cerebrium.toml:
[cerebrium.build]
predict_data = "{\"prompt\": \"Here is some example predict data for your config.yaml which will be used to test your predict function on build.\"}"
hide_public_endpoint = false
disable_animation = false
disable_build_logs = false
disable_syntax_check = false
disable_predict = false
log_level = "INFO"
disable_confirmation = false
[cerebrium.deployment]
name = "1-faster-inference-with-vllm"
python_version = "3.11"
include = ["./*", "main.py", "cerebrium.toml"]
exclude = ["./example_exclude"]
docker_base_image_url = "nvidia/cuda:12.1.1-runtime-ubuntu22.04"
[cerebrium.hardware]
region = "us-east-1"
provider = "aws"
compute = "AMPERE_A10"
cpu = 2
memory = 16.0
gpu_count = 1
[cerebrium.scaling]
min_replicas = 0
max_replicas = 5
cooldown = 60
[cerebrium.dependencies.pip]
huggingface-hub = "latest"
sentencepiece = "latest"
torch = ">=2.0.0"
vllm = "latest"
transformers = ">=4.35.0"
accelerate = "latest"
xformers = "latest"
[cerebrium.dependencies.conda]
[cerebrium.dependencies.apt]
ffmpeg = "latest"
Deploy the model using this command:
cerebrium deploy
After deployment, make this request:
curl --location --request POST 'https://api.cortex.cerebrium.ai/v4/p-<YOUR PROJECT ID>/1-faster-inference-with-vllm/predict' \
--header 'Authorization: Bearer <YOUR TOKEN HERE>' \
--header 'Content-Type: application/json' \
--data-raw '{
    "prompt": "What is the capital city of France?"
}'
The endpoint returns results in this format:
{
  "run_id": "nZL6mD8q66u4lHTXcqmPCc6pxxFwn95IfqQvEix0gHaOH4gkHUdz1w==",
  "message": "Finished inference request with run_id: `nZL6mD8q66u4lHTXcqmPCc6pxxFwn95IfqQvEix0gHaOH4gkHUdz1w==`",
  "result": {
    "result": ["\nA: Paris"]
  },
  "status_code": 200,
  "run_time_ms": 151.24988555908203
}
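If you prefer calling the endpoint from Python rather than curl, here is a sketch using the requests library; the project ID and API token placeholders are the same ones used in the curl example above:

import requests

# Substitute your own project ID and API token from the dashboard
url = "https://api.cortex.cerebrium.ai/v4/p-<YOUR PROJECT ID>/1-faster-inference-with-vllm/predict"
headers = {
    "Authorization": "Bearer <YOUR TOKEN HERE>",
    "Content-Type": "application/json",
}
payload = {"prompt": "What is the capital city of France?"}

response = requests.post(url, headers=headers, json=payload)
data = response.json()
print(data["result"]["result"][0])  # e.g. "\nA: Paris"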