"""
Triton Python Backend for TensorRT-LLM.
"""
import numpy as np
import triton_python_backend_utils as pb_utils
import torch
from tensorrt_llm import LLM, SamplingParams, BuildConfig
from tensorrt_llm.plugin.plugin import PluginConfig
from transformers import AutoTokenizer
MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct"
MODEL_DIR = f"/persistent-storage/models/{MODEL_ID}"
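# NOTE: assumes the Hugging Face checkpoint has already been downloaded to
# MODEL_DIR on the persistent volume before the server starts (for example via
# `huggingface-cli download`); the tokenizer and LLM below load from that path.
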

class TritonPythonModel:
    def initialize(self, args):
        """Load the tokenizer and initialize TensorRT-LLM with the PyTorch backend."""
        print("Loading tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)

        print("Initializing TensorRT-LLM...")
        plugin_config = PluginConfig.from_dict({
            "paged_kv_cache": True,
        })
        build_config = BuildConfig(
            plugin_config=plugin_config,
            max_input_len=4096,
            max_batch_size=128,  # Matches Triton max_batch_size in config.pbtxt
        )
        self.llm = LLM(
            model=MODEL_DIR,
            build_config=build_config,
            tensor_parallel_size=torch.cuda.device_count(),
        )
        print("✓ Model ready")

    def execute(self, requests):
        """
        Execute inference on batched requests.

        Triton automatically batches requests (up to max_batch_size: 128).
        This function processes the batch that Triton provides.
        """
        try:
            prompts = []
            sampling_params_list = []
            original_prompts = []

            # Extract data from each request in the batch. We need to loop over requests:
            # https://github.com/triton-inference-server/python_backend?tab=readme-ov-file#execute
            for request in requests:
                try:
                    # Get input text - handle batched tensor structures
                    input_tensor = pb_utils.get_input_tensor_by_name(request, "text_input")
                    text_array = input_tensor.as_numpy()

                    # Extract the first element regardless of array shape
                    if text_array.ndim == 0:
                        text = text_array.item()
                    else:
                        text = text_array.flat[0]

                    # Decode if bytes
                    if isinstance(text, bytes):
                        text = text.decode('utf-8')
                    elif isinstance(text, np.str_):
                        text = str(text)

                    # Get optional parameters with defaults
                    max_tokens = 1024
                    max_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "max_tokens")
                    if max_tokens_tensor is not None:
                        max_tokens_array = max_tokens_tensor.as_numpy()
                        max_tokens = int(max_tokens_array.item() if max_tokens_array.ndim == 0 else max_tokens_array.flat[0])

                    temperature = 0.8
                    temperature_tensor = pb_utils.get_input_tensor_by_name(request, "temperature")
                    if temperature_tensor is not None:
                        temp_array = temperature_tensor.as_numpy()
                        temperature = float(temp_array.item() if temp_array.ndim == 0 else temp_array.flat[0])

                    top_p = 0.95
                    top_p_tensor = pb_utils.get_input_tensor_by_name(request, "top_p")
                    if top_p_tensor is not None:
                        top_p_array = top_p_tensor.as_numpy()
                        top_p = float(top_p_array.item() if top_p_array.ndim == 0 else top_p_array.flat[0])

                    # Format prompt using chat template
                    prompt = self.tokenizer.apply_chat_template(
                        [{"role": "user", "content": text}],
                        tokenize=False,
                        add_generation_prompt=True
                    )
                    prompts.append(prompt)
                    original_prompts.append(prompt)
                    sampling_params_list.append(SamplingParams(
                        temperature=temperature,
                        top_p=top_p,
                        max_tokens=max_tokens,
                    ))
                except Exception as e:
                    # Keep the lists aligned with the request batch so every
                    # request still gets a response below.
                    print(f"Error processing request: {e}", flush=True)
                    prompts.append("")
                    original_prompts.append("")
                    sampling_params_list.append(SamplingParams(max_tokens=1024))

            # Batch inference
            if not prompts:
                return []
            outputs = self.llm.generate(prompts, sampling_params_list)

            # Create responses
            responses = []
            for i, output in enumerate(outputs):
                try:
                    generated_text = output.outputs[0].text

                    # Strip prompt from output if included
                    if original_prompts[i] and original_prompts[i] in generated_text:
                        generated_text = generated_text.replace(original_prompts[i], "").strip()

                    responses.append(pb_utils.InferenceResponse(
                        output_tensors=[pb_utils.Tensor(
                            "text_output",
                            np.array([generated_text.encode('utf-8')], dtype=object)
                        )]
                    ))
                except Exception as e:
                    print(f"Error creating response {i}: {e}", flush=True)
                    responses.append(pb_utils.InferenceResponse(
                        output_tensors=[pb_utils.Tensor(
                            "text_output",
                            np.array([f"Error: {str(e)}".encode('utf-8')], dtype=object)
                        )]
                    ))

            return responses
        except Exception as e:
            print(f"Error in execute: {e}", flush=True)
            return [
                pb_utils.InferenceResponse(
                    output_tensors=[pb_utils.Tensor(
                        "text_output",
                        np.array([f"Batch error: {str(e)}".encode('utf-8')], dtype=object)
                    )]
                )
                for _ in requests
            ]

    def finalize(self):
        """Cleanup on shutdown."""
        if hasattr(self, 'llm'):
            self.llm.shutdown()
        torch.cuda.empty_cache()
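
# For reference, a client-side request might look roughly like the sketch below.
# Assumptions (not part of this file): the model is registered in the Triton model
# repository under the hypothetical name "llama_trtllm" and the HTTP endpoint is
# exposed on localhost:8000.
#
#   import numpy as np
#   import tritonclient.http as httpclient
#
#   client = httpclient.InferenceServerClient(url="localhost:8000")
#
#   # A batch of one request holding a single string element: shape [1, 1]
#   # (the batch dimension plus dims: [ 1 ] from the model config).
#   text = np.array([["Write a haiku about GPUs."]], dtype=object)
#   text_input = httpclient.InferInput("text_input", list(text.shape), "BYTES")
#   text_input.set_data_from_numpy(text)
#
#   result = client.infer(model_name="llama_trtllm", inputs=[text_input])
#   print(result.as_numpy("text_output").flat[0].decode("utf-8"))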