| from typing import Dict, List, Any |
| from transformers import DonutProcessor, VisionEncoderDecoderModel |
| import torch |
|
|
|
|
| |
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
class EndpointHandler:
    """Callable request handler that runs a Donut model on one input.

    The processor and model are loaded once when the handler is constructed;
    every call preprocesses the request input, generates a token sequence and
    returns the parsed result.
    """

    def __init__(self, path=""):
        """Load the processor and model and tokenize the task prompt.

        Args:
            path: Location passed to ``from_pretrained`` for both the
                ``DonutProcessor`` and the ``VisionEncoderDecoderModel``.
        """
        self.processor = DonutProcessor.from_pretrained(path)
        self.model = VisionEncoderDecoderModel.from_pretrained(path)
        self.model.to(device)

        # The task prompt that seeds the decoder is the same for every
        # request, so it is tokenized once here instead of on each call.
        self.decoder_input_ids = self.processor.tokenizer(
            "<s_cord-v2>", add_special_tokens=False, return_tensors="pt"
        ).input_ids

    def __call__(self, data: Any) -> Dict[str, Any]:
        """Run generation for one request and return the parsed prediction.

        Args:
            data: Request payload. When it is a dict, the value under
                ``"inputs"`` is used (the dict itself if that key is
                missing); any other object is used as the input directly.
                The input is handed to the processor unchanged --
                presumably an image the processor accepts; verify against
                the caller.

        Returns:
            The result of ``self.processor.token2json`` applied to the first
            decoded sequence.
        """
        # Read the payload without mutating the caller's dict.
        inputs = data.get("inputs", data) if isinstance(data, dict) else data

        pixel_values = self.processor(inputs, return_tensors="pt").pixel_values

        tokenizer = self.processor.tokenizer
        outputs = self.model.generate(
            pixel_values.to(device),
            decoder_input_ids=self.decoder_input_ids.to(device),
            max_length=self.model.decoder.config.max_position_embeddings,
            early_stopping=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            use_cache=True,
            num_beams=1,
            # Forbid the unknown token from ever being generated.
            bad_words_ids=[[tokenizer.unk_token_id]],
            return_dict_in_generate=True,
        )

        # Only the first generated sequence is decoded and parsed.
        prediction = self.processor.batch_decode(outputs.sequences)[0]
        return self.processor.token2json(prediction)
|
|