In this tutorial, we will transcribe an hour-long audio file using Distil Whisper, an optimised version of Whisper-large-v2 that is 60% faster while staying within 1% of its error rate. We will accept either a base64-encoded string of the audio file or a URL from which we can download the audio file.

To see the final implementation, you can view it here

Basic Setup

It is important to note that the way you develop models using Cerebrium should be identical to developing on a virtual machine or Google Colab, so converting existing code should be very easy! Please make sure you have the Cerebrium package installed and have logged in. If not, please take a look at our docs here.

First we create our project:

cerebrium init distil-whisper

Let us add the following packages to the [cerebrium.dependencies.pip] section of our cerebrium.toml file:

[cerebrium.dependencies.pip]
accelerate = "latest"
transformers = ">=4.35.0"
openai-whisper = "latest"

To start, let us create a util.py file for our utility functions: downloading a file from a URL and converting a base64 string to a file. Our util.py will look something like this:

import base64
import uuid

DOWNLOAD_ROOT = "/tmp/"  # Change this to /persistent-storage/ if you want to save files to the persistent storage

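# Downloads a file from a public URL and saves it locally under `filename`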
def download_file_from_url(logger, url: str, filename: str):
    logger.info("Downloading file...")

    import requests

    response = requests.get(url)
    if response.status_code == 200:
        logger.info("Download was successful")

        with open(filename, "wb") as f:
            f.write(response.content)

        return filename

    else:
        logger.info(response)
        raise Exception("Download failed")


# Saves a base64 encoded file string to a local file
def save_base64_string_to_file(logger, audio: str):
    logger.info("Converting file...")

    decoded_data = base64.b64decode(audio)

    filename = f"{DOWNLOAD_ROOT}/{uuid.uuid4()}"

    with open(filename, "wb") as file:
        file.write(decoded_data)

    logger.info("Decoding base64 to file was successful")
    return filename
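
If you want to sanity-check these helpers locally before deploying, a minimal sketch (using Python's standard logging module as a stand-in for the logger Cerebrium passes in, and a hypothetical test URL) looks like this:

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("util-test")

# Replace with any publicly reachable audio file you control
path = download_file_from_url(logger, "https://example.com/test.mp3", "/tmp/test.mp3")
print(f"Downloaded to {path}")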

Now that our utility functions are complete, go to the main.py file, which will contain our main Python code. We would like the user to send us either a base64-encoded string of the file or a public URL from which we can download the file. We then pass this file to our model and return the output to the user. So let us define our request object.

from typing import Optional
from pydantic import BaseModel, HttpUrl

class Item(BaseModel):
    audio: Optional[str] = None
    file_url: Optional[HttpUrl] = None
    webhook_endpoint: Optional[HttpUrl] = None

Above, we use Pydantic as our data validation library. Due to the way we have defined the BaseModel, "audio" and "file_url" are optional parameters, but we must check that we are given one or the other. The webhook_endpoint parameter is something Cerebrium automatically includes in every request and can be used for long-running requests. Currently, Cerebrium has a max timeout of 3 minutes for each inference request. For long audio files (2 hours) that take a couple of minutes to process, it is best to use a webhook_endpoint, which is a URL we will make a POST request to with the results of your function.
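
If you would rather reject invalid requests at validation time instead of inside the predict function (which we do further down), one option is a root validator on the model. This is a sketch assuming Pydantic v1's root_validator; adjust accordingly if you are on Pydantic v2:

from typing import Optional
from pydantic import BaseModel, HttpUrl, root_validator

class Item(BaseModel):
    audio: Optional[str] = None
    file_url: Optional[HttpUrl] = None
    webhook_endpoint: Optional[HttpUrl] = None

    @root_validator
    def check_audio_or_file_url(cls, values):
        # Require at least one of the two inputs
        if values.get("audio") is None and values.get("file_url") is None:
            raise ValueError("Provide either `audio` (base64 string) or `file_url`")
        return values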

Setup Model and Inference

Below, we import the required packages and load our Whisper model. The model is downloaded during your deployment, and on subsequent deploys or inference requests it is automatically cached in your persistent storage for reuse. You can read more about persistent storage here. We do this outside our predict function since we only want this code to run on a cold start (i.e. on startup). If the container is already warm, we just want it to do inference, and it will execute only the predict function.

from huggingface_hub import hf_hub_download
from whisper import load_model, transcribe
from util import download_file_from_url, save_base64_string_to_file

distil_large_v2 = hf_hub_download(repo_id="distil-whisper/distil-large-v2", filename="original-model.bin")
model = load_model(distil_large_v2)

def predict(item, run_id, logger):
    item = Item(**item)
    input_filename = f"{run_id}.mp3"

    if item.audio is not None:
        file = save_base64_string_to_file(logger, item.audio)
    elif item.file_url is not None:
        file = download_file_from_url(logger, item.file_url, input_filename)
    else:
        raise Exception("Either `audio` or `file_url` must be provided")

    logger.info("Transcribing file...")

    result = transcribe(model, audio=file)

    return result

In our predict function, which only runs on inference requests, we simply create an audio file from the download URL or base64 string given to us via the request. We then transcribe the file and return the output to the user.

Deploy

Your cerebrium.toml file is where you can set your compute/environment. Please make sure that the GPU you specify is an AMPERE_A5000 and that you have enough memory (RAM) on your instance to run the models. Your cerebrium.toml file should look like this:


[cerebrium.build]
predict_data = "{\"prompt\": \"Here is some example predict data for your cerebrium.toml which will be used to test your predict function on build.\"}"
force_rebuild = false
disable_animation = false
log_level = "INFO"
disable_deployment_confirmation = false

[cerebrium.deployment]
name = "controlnet-logo"
python_version = "3.10"
include = "[./*, main.py]"
exclude = "[./.*, ./__*]"

[cerebrium.hardware]
gpu = "AMPERE_A5000"
cpu = 2
memory = 10.0
gpu_count = 1

[cerebrium.scaling]
min_replicas = 0
cooldown = 60

[cerebrium.dependencies.apt]

[cerebrium.dependencies.pip]
accelerate = "latest"
transformers = ">=4.35.0"
openai-whisper = "latest"

[cerebrium.dependencies.conda]

To deploy the model use the following command:

cerebrium deploy distil-whisper

Once deployed, we can make the following request:

curl --location 'https://run.cerebrium.ai/v3/p-xxxxx/distil-whisper/predict' \
--header 'Content-Type: application/json' \
--header 'Authorization: <JWT_TOKEN>' \
--data '{"file_url": "https://your-public-url.com/test.mp3"}''

You will notice that you get an immediate response with a 202 status code and a run_id. This run_id is a unique identifier that allows you to correlate the result with the initial request.

Our endpoint will then receive the following result:

{
  "run_id": "2R5PnHprwNqiS5tcFMor-4c6rSrxuzrVtBU1JfjT5iWFG6s4pHo1Ug==",
  "message": "Finished inference request with run_id: `2R5PnHprwNqiS5tcFMor-4c6rSrxuzrVtBU1JfjT5iWFG6s4pHo1Ug==`",
  "result": {
    "text": " Testing, one, two, three, testing.",
    "segments": [
      {
        "id": 0,
        "seek": 0,
        "start": 0,
        "end": 4,
        "text": " Testing, one, two, three, testing.",
        "tokens": [
          50364, 45517, 11, 472, 11, 220, 20534, 11, 220, 27583, 11, 220, 83,
          8714, 13, 50564
        ],
        "temperature": 0,
        "avg_logprob": -0.3824356023003073,
        "compression_ratio": 1,
        "no_speech_prob": 0.019467202946543694
      }
    ],
    "language": "en"
  },
  "status_code": 200,
  "run_time_ms": 2053.8525581359863
}