| """ |
| Cog prediction script for the PULSE ECG model. |
| |
| This module defines a ``Predictor`` class compatible with the Replicate |
| Cog framework. It delegates model loading and inference to the |
| ``EndpointHandler`` defined in ``handler.py``. The predictor exposes a |
| simple ``predict`` method that accepts an image and a prompt, along with |
| optional sampling parameters. The response is the generated text |
| answer from the model. |
| """ |
|
|
| from typing import Optional |
|
|
| from cog import BasePredictor, Input, Path |
|
|
| from handler import EndpointHandler |
|
|
|
|
class Predictor(BasePredictor):
    """Cog predictor for the PULSE ECG model.

    Thin adapter between Cog's typed ``predict`` inputs and the
    ``EndpointHandler`` from ``handler.py``, which performs model loading
    and generation.
    """

    def setup(self) -> None:
        """Load the model once, when the Cog server starts.

        Instantiates the ``EndpointHandler``. The model weights and vision
        tower are loaded during the handler's own initialisation, so this
        cost is paid once rather than per prediction.
        """
        self.handler = EndpointHandler()

    def predict(
        self,
        image: Path = Input(description="Input ECG image file"),
        prompt: str = Input(description="Question to ask about the ECG"),
        temperature: float = Input(
            description="Randomness of generation; 0 for deterministic outputs",
            default=0.0,
            ge=0.0,
        ),
        top_p: float = Input(
            description="Nucleus sampling parameter; consider tokens in the top p cumulative probability",
            default=0.9,
            ge=0.0,
            le=1.0,
        ),
        max_tokens: int = Input(
            description="Maximum number of new tokens to generate",
            default=512,
            ge=0,
        ),
        # NOTE(review): ge=0.0 admits a penalty of exactly 0.0; many
        # generation backends require a strictly positive penalty --
        # confirm how handler.py treats it.
        repetition_penalty: float = Input(
            description="Penalise repetition; 1.0 means no penalty",
            default=1.0,
            ge=0.0,
        ),
        conv_mode: Optional[str] = Input(
            description="Override the conversation template (e.g. 'llava_v1')",
            default=None,
        ),
    ) -> str:
        """Generate a textual response for an ECG image and prompt.

        Parameters
        ----------
        image: Path
            Path to the input image file. Cog saves uploaded images to a
            temporary location and passes that path here.
        prompt: str
            The question to ask about the ECG image.
        temperature: float
            Sampling temperature; higher values yield more random results.
        top_p: float
            Top-p (nucleus) sampling; lower values focus on more likely
            tokens.
        max_tokens: int
            Maximum number of tokens to generate beyond the prompt. Sent
            to the handler under the key ``"max_new_tokens"``.
        repetition_penalty: float
            Penalty for repeating tokens; values > 1.0 discourage
            repetition.
        conv_mode: Optional[str]
            Optional conversation template override. When non-empty, it
            is forwarded to the handler; otherwise the handler chooses the
            template itself.

        Returns
        -------
        str
            The generated answer from the model. Returns an empty string
            if the handler's result dict has no usable text.

        Raises
        ------
        ValueError
            If the handler's result dict contains an ``"error"`` key.
        """
        event = {
            "image": str(image),
            "prompt": prompt,
            "temperature": temperature,
            "top_p": top_p,
            "max_new_tokens": max_tokens,
            "repetition_penalty": repetition_penalty,
        }
        # Treat an empty string from the Cog form the same as "not
        # provided", so the handler falls back to its own template choice.
        if conv_mode:
            event["conv_mode"] = conv_mode

        return self._extract_text(self.handler(event))

    @staticmethod
    def _extract_text(result) -> str:
        """Normalise the handler's return value into the answer string.

        Unwraps a single-element list, raises ``ValueError`` on an
        ``"error"`` entry, and prefers ``"generated_text"`` over
        ``"answer"``. Any other shape is stringified.
        """
        # Hugging Face-style custom handlers conventionally return a list
        # of dicts. Unwrap a single-element list so callers get the text,
        # not the list's repr.
        # NOTE(review): handler.py is not visible here -- confirm its
        # actual return shape.
        if isinstance(result, (list, tuple)) and len(result) == 1:
            result = result[0]

        if isinstance(result, dict):
            if "error" in result:
                raise ValueError(result["error"])
            # Skip keys whose value is None so the declared ``str`` return
            # type holds and a usable fallback key is not shadowed.
            for key in ("generated_text", "answer"):
                text = result.get(key)
                if text is not None:
                    return str(text)
            return ""

        return str(result)