In this tutorial, we will show you how to deploy Mistral 7B using the popular vLLM inference framework.

To see the final implementation, you can view it here

Basic Setup

Developing models with Cerebrium should feel identical to developing on a virtual machine or Google Colab, so converting existing code should be very easy. Please make sure you have the Cerebrium package installed and have logged in. If not, please take a look at our docs here

First we create our project:

cerebrium init mistral-vllm

We need certain Python packages to implement this project. Let's add those to the [cerebrium.dependencies.pip] section of our cerebrium.toml file:

sentencepiece = "latest"
torch = ">=2.0.0"
vllm = "latest"
transformers = ">=4.35.0"
accelerate = "latest"
xformers = "latest"

Our main Python file will contain our code. This is a relatively simple implementation, so we can do everything in one file. We would like a user to send in a prompt, optionally with sampling parameters, and get the generated text back. So let us define our request object.

from typing import Optional

from pydantic import BaseModel

class Item(BaseModel):
    prompt: str
    temperature: Optional[float] = 0.8
    top_p: Optional[float] = 0.75
    top_k: Optional[int] = 40
    max_tokens: Optional[int] = 256
    frequency_penalty: Optional[float] = 1.0

Above, we use Pydantic as our data validation library. We mark which parameters are required and which are optional (using the Optional type hint), and assign defaults to the optional ones. prompt is the only required parameter, so if it is missing from the request, the user automatically receives an error message.
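
To make the validation behaviour concrete, here is a minimal sketch that redefines a trimmed-down version of the request object (only three of the fields, to keep it short) and exercises both the happy path and the missing-prompt case. It assumes Pydantic is installed locally; the sample payloads are made up for illustration.

```python
from typing import Optional

from pydantic import BaseModel, ValidationError


class Item(BaseModel):
    # Trimmed-down copy of the request object, for illustration only.
    prompt: str
    temperature: Optional[float] = 0.8
    max_tokens: Optional[int] = 256


# A payload with only the required field: the defaults fill in the rest.
ok = Item(**{"prompt": "Hello"})
print(ok.temperature, ok.max_tokens)

# A payload without "prompt" is rejected before any model code runs.
try:
    Item(**{"temperature": 0.5})
    missing_prompt_rejected = False
except ValidationError as err:
    missing_prompt_rejected = True
    print(err)
```

The same check happens on the first line of the predict function further down, so a malformed request fails fast instead of reaching the model.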

vLLM Implementation

Below, we will use vLLM to load the Mistral 7B Instruct model and generate a completion for the prompt, using the sampling parameters sent in the request.

import torch
from vllm import LLM, SamplingParams

llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.1", dtype="bfloat16")

def predict(item, run_id, logger):
    item = Item(**item)

    # Now just setup your sampling parameters for inference:
    sampling_params = SamplingParams(temperature=item.temperature, top_p=item.top_p, top_k=item.top_k, max_tokens=item.max_tokens, frequency_penalty=item.frequency_penalty)

    # And feed your prompt and sampling params into your LLM pipeline as follows.
    outputs = llm.generate([item.prompt], sampling_params)

    # Extract your text outputs:
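    generated_text = []
    for output in outputs:
        # Each request output holds one or more completions; take the first.
        generated_text.append(output.outputs[0].text)
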
    # And return the result
    return {"result": generated_text}

We load the model outside the predict function because the API runs predict on every request, and reloading the model each time would add significant latency. Code outside the predict function runs only once, on startup, i.e. during a cold start.

The implementation of our predict function is pretty straightforward: we pass the input parameters from our request into the model and then return the generated outputs to the user.
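
The load-once pattern described above can be sketched without a GPU. In the example below, load_model is a hypothetical stand-in for the expensive LLM(...) call and simply echoes the prompt; it is not part of vLLM or Cerebrium. The counter shows that the load happens once, however many requests arrive.

```python
LOAD_COUNT = 0


def load_model():
    """Hypothetical stand-in for an expensive load such as LLM(model=...)."""
    global LOAD_COUNT
    LOAD_COUNT += 1
    return lambda prompt: f"echo: {prompt}"


# Module scope: runs once, when the process starts (the cold start).
model = load_model()


def predict(item, run_id, logger):
    # Runs on every request and reuses the already-loaded model.
    return {"result": [model(item["prompt"])]}


# Simulate three incoming requests.
for i in range(3):
    predict({"prompt": f"request {i}"}, run_id=str(i), logger=None)

print(LOAD_COUNT)
```

If load_model were called inside predict instead, the counter would grow with every request, which is exactly the cost we avoid by loading at module scope.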


Your cerebrium.toml file is where you can set your compute/environment. Please make sure that the GPU you specify is an AMPERE_A5000, and that you have enough memory (RAM) on your instance to run the models. Your cerebrium.toml file should look like:

[cerebrium.build]
predict_data = "{\"prompt\": \"Here is some example predict data for your cerebrium.toml which will be used to test your predict function on build.\"}"
force_rebuild = false
disable_animation = false
log_level = "INFO"
disable_deployment_confirmation = false

[cerebrium.deployment]
name = "mistral-vllm"
python_version = "3.10"
include = "[./*,]"
exclude = "[./.*, ./__*]"

[cerebrium.hardware]
gpu = "AMPERE_A5000"
cpu = 2
memory = 16.0
gpu_count = 1

[cerebrium.scaling]
min_replicas = 0
cooldown = 60

[cerebrium.dependencies.apt]
ffmpeg = "latest"

[cerebrium.dependencies.pip]
sentencepiece = "latest"
torch = ">=2.0.0"
vllm = "latest"
transformers = ">=4.35.0"
accelerate = "latest"
xformers = "latest"
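
As a rough sanity check on the GPU choice, we can estimate how much GPU memory the model weights need. This is a back-of-the-envelope sketch: the parameter count is an approximate figure that does not come from this tutorial, and the estimate covers only the weights, not the KV cache or activations that vLLM also allocates. Note that this concerns GPU memory, which is separate from the memory (RAM) setting in cerebrium.toml.

```python
PARAMS = 7.24e9        # approximate parameter count for Mistral 7B (assumption)
BYTES_PER_PARAM = 2    # bfloat16 stores each weight in 2 bytes
GPU_MEMORY_GB = 24     # an A5000 has 24 GB of GPU memory

# Memory needed just to hold the weights, in GB (1 GB = 1e9 bytes here).
weights_gb = PARAMS * BYTES_PER_PARAM / 1e9

# What remains for the KV cache, activations and framework overhead.
headroom_gb = GPU_MEMORY_GB - weights_gb

print(f"weights: {weights_gb:.1f} GB, headroom: {headroom_gb:.1f} GB")
```

Under these assumptions the weights fit on a single A5000 with room to spare.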


To deploy the model, use the following command:

cerebrium deploy

Once deployed, we can make the following request:

curl --location --request POST '' \
--header 'Authorization: <JWT-TOKEN>' \
--header 'Content-Type: application/json' \
--data-raw '{
    "prompt": "What is the capital city of France?"
}'

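If you would rather call the endpoint from Python, the same request can be sketched with the standard library. The endpoint URL below is a placeholder, not a real address, so substitute the URL of your own deployment; build_request and extract_text are hypothetical helper names used only for this example. The sketch builds the request without sending it.

```python
import json
import urllib.request


def build_request(endpoint: str, token: str, prompt: str) -> urllib.request.Request:
    """Build the same POST request as the curl command, without sending it."""
    body = json.dumps({"prompt": prompt}).encode("utf-8")
    return urllib.request.Request(
        endpoint,
        data=body,
        headers={"Authorization": token, "Content-Type": "application/json"},
        method="POST",
    )


def extract_text(response_body: dict) -> list:
    """Pull the generated strings out of the nested "result" field."""
    return response_body["result"]["result"]


# Placeholder endpoint (assumption): replace with your deployment's URL.
req = build_request(
    "https://example.invalid/predict",
    "<JWT-TOKEN>",
    "What is the capital city of France?",
)

# To actually send it:
# with urllib.request.urlopen(req) as resp:
#     print(extract_text(json.load(resp)))
```
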
We then get the following results:

{
  "run_id": "nZL6mD8q66u4lHTXcqmPCc6pxxFwn95IfqQvEix0gHaOH4gkHUdz1w==",
  "message": "Finished inference request with run_id: `nZL6mD8q66u4lHTXcqmPCc6pxxFwn95IfqQvEix0gHaOH4gkHUdz1w==`",
  "result": {
    "result": ["\nA: Paris"]
  },
  "status_code": 200,
  "run_time_ms": 151.24988555908203
}