"""Custom inference handler for Hugging Face Inference Endpoints.

NLLB needs a source-language code on the tokenizer and a forced BOS token id
for the target language at generation time, so the default translation
pipeline is not flexible enough. This handler accepts ``src_lang`` and
``tgt_lang`` (NLLB Flores-200 codes, e.g. "eng_Latn", "spa_Latn") per request.

Request format::

    {
        "inputs": "Hello, world!",   # str or List[str]
        "parameters": {
            "src_lang": "eng_Latn",  # optional, default eng_Latn
            "tgt_lang": "spa_Latn",  # optional, default spa_Latn
            "max_length": 256,       # optional, default 256
            "num_beams": 4,          # optional, default 4
            "temperature": 1.0,      # optional, only used when do_sample is true
            "do_sample": false       # optional, default false
        }
    }

A parameter given as ``null`` is treated the same as an absent one.

Response: ``List[{"translation_text": str}]`` with one entry per input, in
input order, or ``[{"error": str}]`` when the request is invalid.
"""

from __future__ import annotations

from typing import Any, Callable, Dict, List, Optional, Union

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

DEFAULT_SRC_LANG = "eng_Latn"
DEFAULT_TGT_LANG = "spa_Latn"
DEFAULT_MAX_LENGTH = 256
DEFAULT_NUM_BEAMS = 4
DEFAULT_TEMPERATURE = 1.0

# String spellings accepted as "true" for boolean request parameters.
_TRUE_STRINGS = frozenset({"1", "true", "yes", "on"})


def _get_param(params: Dict[str, Any], key: str, default: Any) -> Any:
    """Return ``params[key]``, falling back to ``default`` if absent or None."""
    value = params.get(key)
    return default if value is None else value


def _get_number(
    params: Dict[str, Any], key: str, default: Any, cast: Callable[[Any], Any]
) -> Any:
    """Read ``params[key]`` (None/absent -> ``default``) and convert it with ``cast``.

    Raises:
        ValueError: with a client-facing message if the value cannot be cast.
    """
    raw = _get_param(params, key, default)
    try:
        return cast(raw)
    except (TypeError, ValueError, OverflowError):
        # OverflowError covers int(float("inf")).
        raise ValueError(f"Parameter {key!r} must be a number, got {raw!r}.") from None


def _as_bool(value: Any) -> bool:
    """Interpret a JSON boolean, number, or string ("true"/"false") as a bool.

    Plain ``bool("false")`` is True, so string values are matched explicitly.
    """
    if isinstance(value, str):
        return value.strip().lower() in _TRUE_STRINGS
    return bool(value)


class EndpointHandler:
    """NLLB translation handler with per-request source/target languages."""

    def __init__(self, path: str = "") -> None:
        """Load the tokenizer and model from ``path``; use CUDA when available.

        Args:
            path: Directory or model id passed to ``from_pretrained``.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # fp16 on GPU keeps latency and memory down; stay in fp32 on CPU for stability.
        dtype = torch.float16 if self.device == "cuda" else torch.float32
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            path, torch_dtype=dtype
        ).to(self.device)
        # Put layers such as dropout into inference behaviour.
        self.model.eval()

    def _lang_token_id(self, code: Any) -> Optional[int]:
        """Return the vocabulary id of language code ``code``, or None if unknown.

        NOTE(review): this only checks that ``code`` is a known, non-<unk>
        token; it does not verify it is specifically a language-code token.
        """
        if not isinstance(code, str) or not code:
            return None
        try:
            token_id = self.tokenizer.convert_tokens_to_ids(code)
        except Exception:
            # Treat any tokenizer-side failure as "unknown code" instead of
            # letting it propagate out of the request.
            return None
        if not isinstance(token_id, int) or token_id == self.tokenizer.unk_token_id:
            return None
        return token_id

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        """Translate ``data["inputs"]`` using the options in ``data["parameters"]``.

        Args:
            data: Request payload; see the module docstring for the schema.

        Returns:
            ``[{"translation_text": ...}, ...]`` with one entry per input, in
            input order; ``[]`` for an empty input list; or a single-element
            ``[{"error": ...}]`` when the request is invalid.
        """
        inputs: Union[str, List[str], None] = data.get("inputs")
        if inputs is None:
            return [{"error": "Missing 'inputs' field."}]
        if isinstance(inputs, str):
            inputs = [inputs]
        # Check the container type first: iterating a non-iterable (e.g. an int)
        # would otherwise raise instead of producing a validation error.
        if not isinstance(inputs, (list, tuple)) or not all(
            isinstance(x, str) for x in inputs
        ):
            return [{"error": "'inputs' must be a string or a list of strings."}]
        if not inputs:
            # Nothing to translate; skip tokenizing/generating an empty batch.
            return []

        params = data.get("parameters") or {}
        if not isinstance(params, dict):
            return [{"error": "'parameters' must be an object."}]

        src_lang = _get_param(params, "src_lang", DEFAULT_SRC_LANG)
        tgt_lang = _get_param(params, "tgt_lang", DEFAULT_TGT_LANG)
        do_sample = _as_bool(_get_param(params, "do_sample", False))
        try:
            max_length = _get_number(params, "max_length", DEFAULT_MAX_LENGTH, int)
            num_beams = _get_number(params, "num_beams", DEFAULT_NUM_BEAMS, int)
            temperature = _get_number(params, "temperature", DEFAULT_TEMPERATURE, float)
        except ValueError as err:
            return [{"error": str(err)}]
        if max_length < 1:
            return [{"error": "'max_length' must be a positive integer."}]
        if num_beams < 1:
            return [{"error": "'num_beams' must be a positive integer."}]
        # `not > 0` (rather than `<= 0`) also rejects NaN.
        if do_sample and not temperature > 0:
            return [{"error": "'temperature' must be > 0 when 'do_sample' is true."}]

        # The target language is selected by forcing its code as the first
        # generated token.
        forced_bos_token_id = self._lang_token_id(tgt_lang)
        if forced_bos_token_id is None:
            return [{"error": f"Unknown target language code: {tgt_lang!r}"}]
        # Assigning an unknown code to tokenizer.src_lang is not guaranteed to
        # raise, so validate the source code the same way as the target.
        if self._lang_token_id(src_lang) is None:
            return [{"error": f"Unknown source language code: {src_lang!r}"}]

        # The tokenizer uses src_lang to choose the source-language prefix.
        # NOTE(review): this mutates state shared by every request on this
        # handler; confirm the runtime does not call one instance concurrently.
        self.tokenizer.src_lang = src_lang
        encoded = self.tokenizer(
            list(inputs),
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
        ).to(self.device)

        gen_kwargs: Dict[str, Any] = {
            "forced_bos_token_id": forced_bos_token_id,
            "max_length": max_length,
            "num_beams": num_beams,
            "do_sample": do_sample,
        }
        if do_sample:
            # temperature only affects sampling; passing a non-default value
            # with greedy/beam search makes transformers warn it is unused.
            gen_kwargs["temperature"] = temperature

        with torch.inference_mode():
            generated = self.model.generate(**encoded, **gen_kwargs)

        decoded = self.tokenizer.batch_decode(generated, skip_special_tokens=True)
        return [{"translation_text": text} for text in decoded]