'''
@license: (C) Copyright 2025, Hey.
@author: Hey
@email: sanyuan.hy@alibaba-inc.com
@tel: 137****6540
@datetime: 2025/12/30 11:33
@project: lucaone
@file: tokenization_lucaone
@desc: HuggingFace-compatible tokenizer for LucaOne gene/protein sequences
'''

import os
import json
from typing import List, Optional, Dict, Any, Tuple, Union

from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast


def gene_seq_replace(seq):
    """
    Gene sequence preprocessing: A -> 1, U/T -> 2, C -> 3, G -> 4,
    and any other character (e.g. N) -> 5.
    Case-insensitive, single pass over the sequence.
    """
    mapping = {
        'A': '1', 'a': '1',
        'T': '2', 't': '2', 'U': '2', 'u': '2',
        'C': '3', 'c': '3',
        'G': '4', 'g': '4',
    }
    return "".join(mapping.get(ch, '5') for ch in seq)
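
# Illustrative behavior of the mapping above (not executed):
#   gene_seq_replace("ATGCn") -> "12435"
#   gene_seq_replace("NNN")   -> "555"   (unknown characters fall back to '5')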


class LucaGPLMTokenizer(PreTrainedTokenizer):
    """
    HuggingFace-compatible tokenizer that performs identical tokenization
    to the old model's Alphabet class.
    """

    gene_prepend_toks = ['[PAD]', '[UNK]']
    gene_append_toks = ['[CLS]', '[SEP]', '[MASK]']
    gene_standard_toks = ['1', '2', '3', '4', '5', '.', '-', '*']

    prot_prepend_toks = ['[PAD]', '[UNK]']
    prot_append_toks = ['[CLS]', '[SEP]', '[MASK]']
    prot_standard_toks = ['L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C', 'X', 'B', 'U', 'Z', 'O', 'J', '.', '-', '*']

    gene_prot_prepend_toks = ['[PAD]', '[UNK]']
    gene_prot_append_toks = ['[CLS]', '[SEP]', '[MASK]']

    # Gene digit tokens first, then the protein alphabet, then shared punctuation.
    gene_prot_standard_toks = ['1', '2', '3', '4', '5', 'L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C', 'X', 'B', 'U', 'Z', 'O', 'J', '.', '-', '*']

    def __init__(
        self,
        vocab_type: str = "gene_prot",
        prepend_bos: bool = True,
        append_eos: bool = True,
        unk_token="[UNK]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        sep_token="[SEP]",
        mask_token="[MASK]",
        **kwargs
    ):
        if vocab_type.lower() == "prot":
            prepend_toks = self.prot_prepend_toks
            append_toks = self.prot_append_toks
            standard_toks = self.prot_standard_toks
        elif vocab_type.lower() == "gene":
            prepend_toks = self.gene_prepend_toks
            append_toks = self.gene_append_toks
            standard_toks = self.gene_standard_toks
        elif vocab_type.lower() in ["gene_prot", "prot_gene"]:
            prepend_toks = self.gene_prot_prepend_toks
            append_toks = self.gene_prot_append_toks
            standard_toks = self.gene_prot_standard_toks
        else:
            raise ValueError(f"Unsupported tokenizer vocab_type: {vocab_type}")

        self.all_toks = list(prepend_toks) + list(append_toks) + list(standard_toks)
        self.tok_to_idx = {tok: i for i, tok in enumerate(self.all_toks)}
        self.idx_to_tok = {i: tok for i, tok in enumerate(self.all_toks)}

        self.vocab_type = vocab_type
        self.prepend_bos = prepend_bos
        self.append_eos = append_eos
        self.unique_no_split_tokens = self.all_toks.copy()

        # Cache special-token ids; the fallbacks match the fixed vocab layout.
        self.unk_idx = self.tok_to_idx.get("[UNK]", 1)
        self.padding_idx = self.tok_to_idx.get("[PAD]", 0)
        self.cls_idx = self.tok_to_idx.get("[CLS]", 2)
        self.mask_idx = self.tok_to_idx.get("[MASK]", 4)
        self.eos_idx = self.tok_to_idx.get("[SEP]", 3)

        # The vocab must be built before PreTrainedTokenizer.__init__, which
        # resolves the special tokens against it.
        super().__init__(
            unk_token=unk_token,
            pad_token=pad_token,
            cls_token=cls_token,
            sep_token=sep_token,
            mask_token=mask_token,
            **kwargs
        )
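
    # Index layout of the default "gene_prot" vocabulary (derived from the
    # token lists above):
    #   [PAD]=0, [UNK]=1, [CLS]=2, [SEP]=3, [MASK]=4,
    #   '1'=5, '2'=6, '3'=7, '4'=8, '5'=9, 'L'=10, ..., '*'=38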

    def get_vocab(self) -> Dict[str, int]:
        return self.tok_to_idx.copy()

    @property
    def vocab_size(self) -> int:
        return len(self.all_toks)

    def get_idx(self, tok):
        return self.tok_to_idx.get(tok, self.unk_idx)

    def get_tok(self, idx):
        return self.idx_to_tok.get(idx, "[UNK]")

    def _tokenize_char_level(self, text: str) -> List[str]:
        """Simple character-level tokenization (fallback)."""
        return list(text)

    def _tokenize(self, text: str) -> List[str]:
        """
        Tokenize text using the same character-level logic as the old
        Alphabet.tokenize() method.
        """
        text = text.strip()
        if not text:
            return []
        return list(text)

    def _convert_token_to_id(self, token: str) -> int:
        return self.get_idx(token)

    def _convert_id_to_token(self, index: int) -> str:
        return self.get_tok(index)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        return "".join(tokens)

    def _convert_text_to_ids(self, text: str, seq_type: str) -> List[int]:
        """Internal helper: convert text to ids without special tokens."""
        if seq_type == "gene":
            text = gene_seq_replace(text)
        tokens = self._tokenize(text)
        return [self._convert_token_to_id(token) for token in tokens]

    def build_inputs_with_special_tokens(self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None) -> List[int]:
        """
        Build model inputs from a sequence by adding special tokens,
        mimicking the old model's prepend_bos/append_eos behavior.
        Sequence pairs (token_ids_1) are not supported and are ignored.
        """
        result = token_ids_0.copy()

        if self.prepend_bos:
            result = [self.cls_idx] + result
        if self.append_eos:
            result = result + [self.eos_idx]

        return result
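
    # Illustrative (default gene_prot vocab, prepend_bos=append_eos=True):
    #   build_inputs_with_special_tokens([5, 6, 8]) -> [2, 5, 6, 8, 3]
    #   i.e. [CLS] '1' '2' '4' [SEP]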

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Return a mask with 1 at special-token positions and 0 elsewhere.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        result = [0] * len(token_ids_0)
        if self.prepend_bos:
            result = [1] + result
        if self.append_eos:
            result = result + [1]
        return result
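
    # Illustrative: get_special_tokens_mask([5, 6, 8]) -> [1, 0, 0, 0, 1]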

    def encode(
        self,
        text: str,
        seq_type: str = "gene",
        add_special_tokens: bool = True,
        padding: Union[bool, str] = False,
        truncation: bool = False,
        max_length: Optional[int] = None,
        **kwargs
    ) -> List[int]:
        # padding is accepted for interface compatibility; it is applied
        # in encode_plus, not here.
        token_ids = self._convert_text_to_ids(text, seq_type)

        if add_special_tokens:
            token_ids = self.build_inputs_with_special_tokens(token_ids)

        if truncation and max_length is not None and len(token_ids) > max_length:
            token_ids = token_ids[:max_length]
            # Truncation may have cut off [SEP]; restore it as the last token.
            if add_special_tokens and self.append_eos:
                token_ids[-1] = self.eos_idx

        return token_ids
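
    # Illustrative, with the default gene_prot vocab:
    #   encode("ATG", seq_type="gene") -> [2, 5, 6, 8, 3]
    #   ("ATG" -> "124" via gene_seq_replace, then wrapped in [CLS]/[SEP])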

    def __call__(
        self,
        text: Union[str, List[str]],
        text_pair: Optional[Union[str, List[str]]] = None,
        seq_type: str = "gene",
        add_special_tokens: bool = True,
        padding: Union[bool, str] = False,
        max_length: Optional[int] = None,
        return_attention_mask: bool = True,
        return_token_type_ids: bool = True,
        return_tensors: Optional[str] = None,
        truncation: bool = False,
        **kwargs
    ) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
        """
        Main tokenization entry point (standard HuggingFace interface).
        """
        if isinstance(text, list):
            return self.batch_encode_plus(
                text,
                text_pair=text_pair,
                seq_type=seq_type,
                add_special_tokens=add_special_tokens,
                padding=padding,
                max_length=max_length,
                return_attention_mask=return_attention_mask,
                return_token_type_ids=return_token_type_ids,
                return_tensors=return_tensors,
                truncation=truncation,
                **kwargs
            )
        else:
            return self.encode_plus(
                text,
                text_pair=text_pair,
                seq_type=seq_type,
                add_special_tokens=add_special_tokens,
                padding=padding,
                max_length=max_length,
                return_attention_mask=return_attention_mask,
                return_token_type_ids=return_token_type_ids,
                return_tensors=return_tensors,
                truncation=truncation,
                **kwargs
            )
|
| | def batch_encode_plus(self, *args, **kwargs): |
| | |
| | |
| | |
| | |
| | batch_outputs = [] |
| | batch_text = kwargs["text"] |
| | seq_type = kwargs["seq_type"] |
| | for text in batch_text: |
| | batch_outputs.append(self.encode_plus(text, seq_type=seq_type, **kwargs)) |
| |
|
| | |
| | |
| | combined = {key: [] for key in batch_outputs[0].keys()} |
| | for output in batch_outputs: |
| | for key, value in output.items(): |
| | combined[key].append(value) |
| |
|
| | return combined |
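
    # Illustrative shape of the merged output for a two-sequence batch:
    #   {"input_ids": [[...], [...]], "attention_mask": [[...], [...]], ...}
    # Note: sequences are only padded to a common length when
    # padding="max_length" is used with an explicit max_length.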

    def encode_plus(
        self,
        text: str,
        text_pair: Optional[str] = None,
        seq_type: str = "gene",
        add_special_tokens: bool = True,
        padding: Union[bool, str] = False,
        max_length: Optional[int] = None,
        return_attention_mask: bool = True,
        return_token_type_ids: bool = True,
        return_tensors: Optional[str] = None,
        truncation: bool = False,
        **kwargs
    ) -> Dict[str, Any]:
        # text_pair is accepted for interface compatibility but ignored.
        token_ids = self.encode(
            text,
            seq_type=seq_type,
            add_special_tokens=add_special_tokens,
            truncation=truncation,
            max_length=max_length
        )

        attention_mask = [1] * len(token_ids)
        if padding == "max_length" and max_length is not None:
            if len(token_ids) < max_length:
                pad_length = max_length - len(token_ids)
                token_ids.extend([self.padding_idx] * pad_length)
                attention_mask.extend([0] * pad_length)

        result = {"input_ids": token_ids}

        if return_attention_mask:
            result["attention_mask"] = attention_mask

        if return_token_type_ids:
            # 0 marks gene tokens, 1 marks protein tokens.
            type_value = 0 if seq_type == "gene" else 1
            result["token_type_ids"] = [type_value] * len(token_ids)

        if return_tensors == "pt":
            import torch
            # unsqueeze(0) adds the batch dimension expected by the model.
            for key, value in result.items():
                result[key] = torch.tensor(value, dtype=torch.long).unsqueeze(0)

        return result
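
    # Illustrative call, padding to a fixed length:
    #   encode_plus("ATG", seq_type="gene", padding="max_length", max_length=8)
    #   -> {"input_ids":      [2, 5, 6, 8, 3, 0, 0, 0],
    #       "attention_mask": [1, 1, 1, 1, 1, 0, 0, 0],
    #       "token_type_ids": [0, 0, 0, 0, 0, 0, 0, 0]}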

    def encode_old_model_style(
        self,
        text: str,
        seq_type: str = "gene",
        max_length: int = None
    ) -> List[int]:
        """
        Encode using the EXACT same process as the old model's encoder function.
        This replicates the logic from src/llm/lucaone_virus/get_embedding.py:encoder()
        """
        # Note: self.encode() already applies gene_seq_replace() for
        # seq_type == "gene"; applying it here as well would corrupt the
        # sequence by remapping the digit tokens a second time.
        seq_encoded = self.encode(text, seq_type=seq_type, add_special_tokens=False)

        # Truncate before adding special tokens (old-model behavior).
        if max_length and len(seq_encoded) > max_length:
            seq_encoded = seq_encoded[:max_length]

        processed_seq_len = len(seq_encoded) + int(self.prepend_bos) + int(self.append_eos)
        input_ids = [self.padding_idx] * processed_seq_len

        if self.prepend_bos:
            input_ids[0] = self.cls_idx

        start_idx = int(self.prepend_bos)
        for i, token_id in enumerate(seq_encoded):
            input_ids[start_idx + i] = token_id

        if self.append_eos:
            input_ids[len(seq_encoded) + int(self.prepend_bos)] = self.eos_idx

        return input_ids
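
    # Unlike encode(..., truncation=True), which truncates AFTER adding
    # [CLS]/[SEP], this method truncates the raw sequence first, so with
    # max_length=N the result has up to N + 2 tokens (when prepend_bos and
    # append_eos are both set).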

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """
        Save the tokenizer vocabulary to a JSON file
        (required by the HuggingFace tokenizer interface).
        """
        if filename_prefix is None:
            filename_prefix = ""
        else:
            filename_prefix = filename_prefix + "-"

        os.makedirs(save_directory, exist_ok=True)
        vocab_file = os.path.join(save_directory, f"{filename_prefix}vocab.json")
        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self.get_vocab(), f, ensure_ascii=False, indent=2)

        return (vocab_file,)
|
| | @classmethod |
| | def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs): |
| | """ |
| | Load tokenizer from pretrained model path (standard HuggingFace interface) |
| | """ |
| | vocab_file = os.path.join(pretrained_model_name_or_path, "vocab.json") |
| | if os.path.exists(vocab_file): |
| | print("Load from saved vocabulary (not implemented yet, use default)") |
| | return cls(vocab_type="gene_prot", **kwargs) |
| | else: |
| | return cls(vocab_type="gene_prot", **kwargs) |
| |
|
| | class LucaGPLMTokenizerFast(PreTrainedTokenizerFast): |
| | """ |
| | Fast tokenizer version - currently just delegates to slow tokenizer |
| | """ |
| | slow_tokenizer_class = LucaGPLMTokenizer |
| | |
| | def __init__(self, **kwargs): |
| | |
| | |
| | super().__init__(**kwargs) |
| |
|
| | __all__ = ["LucaGPLMTokenizer", "LucaGPLMTokenizerFast", "gene_seq_replace"] |
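

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative only, not part of the API).
    tokenizer = LucaGPLMTokenizer(vocab_type="gene_prot")

    # Gene sequence: "ATG" is remapped to "124", then wrapped in [CLS]/[SEP].
    print(tokenizer.encode("ATG", seq_type="gene"))   # expected: [2, 5, 6, 8, 3]

    # Protein sequence: characters are looked up directly in the vocab.
    print(tokenizer.encode("MKV", seq_type="prot"))

    # Batched call; per-sequence outputs are merged field by field.
    batch = tokenizer(["ATG", "GGC"], seq_type="gene",
                      padding="max_length", max_length=8)
    print(batch["input_ids"])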